initial commit

This commit is contained in:
Peter Thorson
2011-09-07 20:27:29 -05:00
commit fa73c478e9
34 changed files with 3581 additions and 0 deletions

13
.gitignore vendored Normal file
View File

@@ -0,0 +1,13 @@
.DS_Store
#vim stuff
*~
*.swp
*.o
*.so
*.a
*.dylib
objs_shared/
objs_static/

169
Makefile Normal file
View File

@@ -0,0 +1,169 @@
objects = websocket_session.o websocket_server.o websocket_frame.o \
network_utilities.o sha1.o base64.o
OS=$(shell uname)
# Defaults
ifeq ($(OS), Darwin)
cxxflags_default = -c -O2 -DNDEBUG
else
cxxflags_default = -c -O2 -DNDEBUG
endif
cxxflags_small = -c
cxxflags_debug = -c -g
cxxflags_shared = -f$(PIC)
libname = libwebsocketpp
libname_hdr = websocketpp
libname_debug = $(libname)_dbg
suffix_shared = so
suffix_shared_darwin = dylib
suffix_static = a
major_version = 0
minor_version = 1.0
objdir = objs
# Variables
prefix ?= /usr/local
exec_prefix ?= $(prefix)
libdir ?= lib
includedir ?= include
srcdir ?= src
CXX ?= c++
AR ?= ar
PIC ?= PIC
BUILD_TYPE ?= "default"
SHARED ?= "1"
# Internal Variables
inst_path = $(exec_prefix)/$(libdir)
include_path = $(prefix)/$(includedir)
# BUILD_TYPE specific settings
ifeq ($(BUILD_TYPE), debug)
CXXFLAGS = $(cxxflags_debug)
libname := $(libname_debug)
else
CXXFLAGS ?= $(cxxflags_default)
endif
# SHARED specific settings
ifeq ($(SHARED), 1)
ifeq ($(OS), Darwin)
libname_shared = $(libname).$(suffix_shared_darwin)
else
libname_shared = $(libname).$(suffix_shared)
endif
libname_shared_major_version = $(libname_shared).$(major_version)
lib_target = $(libname_shared_major_version).$(minor_version)
objdir := $(objdir)_shared
CXXFLAGS := $(CXXFLAGS) $(cxxflags_shared)
else
lib_target = $(libname).$(suffix_static)
objdir := $(objdir)_static
endif
# Phony targets
.PHONY: all banner installdirs install install_headers clean uninstall \
uninstall_headers
# Targets
all: $(lib_target)
@echo "============================================================"
@echo "Done"
@echo "============================================================"
banner:
@echo "============================================================"
@echo "libwebsocketpp version: "$(major_version).$(minor_version) "target: "$(target) "OS: "$(OS)
@echo "============================================================"
installdirs: banner
mkdir -p $(objdir)
# Libraries
ifeq ($(SHARED),1)
$(lib_target): banner installdirs $(addprefix $(objdir)/, $(objects))
@echo "Link "
cd $(objdir) ; \
if test "$(OS)" = "Darwin" ; then \
$(CXX) -dynamiclib -lboost_system -Wl,-dylib_install_name -Wl,$(libname_shared_major_version) -o $@ $(objects) ; \
else \
$(CXX) -shared -lboost_system -Wl,-soname,$(libname_shared_major_version) -o $@ $(objects) ; \
fi ; \
mv -f $@ ../
@echo "Link: Done"
else
$(lib_target): banner installdirs $(addprefix $(objdir)/, $(objects))
@echo "Archive"
cd $(objdir) ; \
$(AR) -cvq $@ $(objects) ; \
mv -f $@ ../
@echo "Archive: Done"
endif
# Compile object files
$(objdir)/sha1.o: $(srcdir)/sha1/sha1.cpp
$(CXX) $< -o $@ $(CXXFLAGS)
$(objdir)/base64.o: $(srcdir)/base64/base64.cpp
$(CXX) $< -o $@ $(CXXFLAGS)
$(objdir)/%.o: $(srcdir)/%.cpp
$(CXX) $< -o $@ $(CXXFLAGS)
ifeq ($(SHARED),1)
install: banner install_headers $(lib_target)
@echo "Install shared library"
cp -f ./$(lib_target) $(inst_path)
cd $(inst_path) ; \
ln -sf $(lib_target) $(libname_shared_major_version) ; \
ln -sf $(libname_shared_major_version) $(libname_shared)
ifneq ($(OS),Darwin)
ldconfig
endif
@echo "Install shared library: Done."
else
install: banner install_headers $(lib_target)
@echo "Install static library"
cp -f ./$(lib_target) $(inst_path)
@echo "Install static library: Done."
endif
install_headers: banner
@echo "Install header files"
mkdir -p $(include_path)/$(libname_hdr)
# cp -f ./*.hpp $(include_path)/$(libname_hdr)
cp -f ./$(srcdir)/*.hpp $(include_path)/$(libname_hdr)
mkdir -p $(include_path)/$(libname_hdr)/base64
cp -f ./$(srcdir)/base64/base64.h $(include_path)/$(libname_hdr)/base64
mkdir -p $(include_path)/$(libname_hdr)/sha1
cp -f ./$(srcdir)/sha1/sha1.h $(include_path)/$(libname_hdr)/sha1
chmod -R a+r $(include_path)/$(libname_hdr)
find $(include_path)/$(libname_hdr) -type d -exec chmod a+x {} \;
@echo "Install header files: Done."
clean: banner
@echo "Clean library and object folder"
rm -rf $(objdir)
rm -f $(lib_target)
@echo "Clean library and object folder: Done"
ifeq ($(SHARED),1)
uninstall: banner uninstall_headers
@echo "Uninstall shared library"
rm -f $(inst_path)/$(libname_shared)
rm -f $(inst_path)/$(libname_shared_major_version)
rm -f $(inst_path)/$(lib_target)
ldconfig
@echo "Uninstall shared library: Done"
else
uninstall: banner uninstall_headers
@echo "Uninstall static library"
rm -f $(inst_path)/$(lib_target)
@echo "Uninstall static library: Done"
endif
uninstall_headers: banner
@echo "Uninstall header files"
rm -rf $(include_path)/$(libname)
@echo "Uninstall header files: Done"

View File

View File

@@ -0,0 +1,30 @@
#include "chat.hpp"
using websocketchat::chat_handler;
bool chat_handler::validate(websocketpp::session_ptr client) {
// We only know about the chat resource
if (client->get_request() != "/chat") {
client->set_http_error(404);
return false;
}
// Require specific origin example
if (client->get_header("Sec-WebSocket-Origin") != "http://zaphoyd.com") {
client->set_http_error(403);
return false;
}
return true;
}
void connect(session_ptr client) {
std::cout << "client " << client << " joined the lobby." << std::endl;
m_connections.insert(client);
// send user list to all clients
// send signon message from server to all other clients.
}

View File

@@ -0,0 +1,85 @@
#ifndef CHAT_HPP
#define CHAT_HPP
// com.zaphoyd.websocketpp.chat protocol
//
// client messages:
// msg [UTF8 text]
//
// server messages:
// {"type":"msg","sender":"<sender>","value":"<msg>" }
// {"type":"participants","value":[<participant>,...]}
#include <boost/shared_ptr.hpp>
#include <websocketpp/websocket_connection_handler.hpp>
#include <set>
#include <map>
#include <string>
#include <vector>
namespace websocketchat {
class chat_handler : public websocketpp::connection_handler {
public:
chat_handler() {}
virtual ~chat_handler() {}
bool validate(websocketpp::session_ptr client);
// add new connection to the lobby
void connect(session_ptr client) {
std::cout << "client " << client << " joined the lobby." << std::endl;
m_connections.insert(client);
// send user list to all clients
// send signon message from server to all other clients.
}
// someone disconnected from the lobby, remove them
void disconnect(websocketpp::session_ptr client,const std::string &reason) {
std::set<session_ptr>::iterator it = m_connections.find(client);
if (it == m_connections.end()) {
// this client has already disconnected, we can ignore this.
// this happens during certain types of disconnect where there is a
// deliberate "soft" disconnection preceeding the "hard" socket read
// fail or disconnect ack message.
return;
}
std::cout << "client " << client << " left the lobby." << std::endl;
m_connections.erase(it);
// send signoff message from server to all clients
// send user list to all remaining clients
}
void message(websocketpp::session_ptr client,const std::string &msg) {
std::cout << "message from client " << client << ": " << msg << std::endl;
// create JSON message to send based on msg
for (std::set<session_ptr>::iterator it = m_connections.begin();
it != m_connections.end(); it++) {
(*it)->send(msg);
}
}
// lobby will ignore binary messages
void message(websocketpp::session_ptr client,
const std::vector<unsigned char> &data) {}
private:
std::string serialize_state();
// list of outstanding connections
std::set<session_ptr> m_connections;
};
typedef boost::shared_ptr<chat_handler> chat_handler_ptr;
}
#endif // CHAT_HPP

View File

@@ -0,0 +1,41 @@
#include "chat.hpp"
#include <websocketpp/websocket_server.hpp>
#include <boost/asio.hpp>
#include <iostream>
using boost::asio::ip::tcp;
int main(int argc, char* argv[]) {
std::string host = "localhost:5000";
short port = 5000;
if (argc == 3) {
port = atoi(argv[2]);
std::stringstream temp;
temp << argv[1] << ":" << port;
host = temp.str();
}
websocketchat::chat_handler_ptr chat_handler(new websocketchat::chat_handler());
try {
boost::asio::io_service io_service;
tcp::endpoint endpoint(tcp::v6(), port);
websocketpp::server_ptr server(
new websocketpp::server(io_service,endpoint,host,chat_handler)
);
std::cout << "Starting chat server on " << host << std::endl;
io_service.run();
} catch (std::exception& e) {
std::cerr << "Exception: " << e.what() << std::endl;
}
return 0;
}

View File

@@ -0,0 +1,23 @@
CFLAGS = -O2
LDFLAGS =
CXX ?= c++
SHARED ?= "1"
ifeq ($(SHARED), 1)
LDFLAGS := $(LDFLAGS) -lboost_system -lwebsocketpp
else
LDFLAGS := $(LDFLAGS) -lboost_system ../../libwebsocketpp.a
endif
echo_server: echo_server.o echo.o
$(CXX) $(CFLAGS) $^ -o $@ $(LDFLAGS)
%.o: %.cpp
$(CXX) -c $(CFLAGS) -o $@ $^
# cleanup by removing generated files
#
.PHONY: clean
clean:
rm -f *.o echo_server

View File

@@ -0,0 +1,16 @@
#include "echo.hpp"
using websocketecho::echo_handler;
bool echo_handler::validate(websocketpp::session_ptr client) {
return true;
}
void echo_handler::message(websocketpp::session_ptr client, const std::string &msg) {
client->send(msg);
}
void echo_handler::message(websocketpp::session_ptr client,
const std::vector<unsigned char> &data) {
client->send(data);
}

View File

@@ -0,0 +1,35 @@
#ifndef ECHO_HANDLER_HPP
#define ECHO_HANDLER_HPP
#include "../../src/websocket_connection_handler.hpp"
#include <boost/shared_ptr.hpp>
#include <string>
#include <vector>
namespace websocketecho {
class echo_handler : public websocketpp::connection_handler {
public:
echo_handler() {}
virtual ~echo_handler() {}
// The echo server allows all domains is protocol free.
bool validate(websocketpp::session_ptr client);
// an echo server is stateless. The handler has no need to keep track of connected
// clients.
void connect(websocketpp::session_ptr client) {}
void disconnect(websocketpp::session_ptr client,const std::string &reason) {}
// both text and binary messages are echoed back to the sending client.
void message(websocketpp::session_ptr client,const std::string &msg);
void message(websocketpp::session_ptr client,
const std::vector<unsigned char> &data);
};
typedef boost::shared_ptr<echo_handler> echo_handler_ptr;
}
#endif // ECHO_HANDLER_HPP

View File

@@ -0,0 +1,94 @@
<!doctype html>
<html>
<head>
</head>
<body>
<script type="text/javascript">
var ws;
var url;
function connect() {
url = document.getElementById("server_url").value;
console.log(url);
if ("WebSocket" in window) {
ws = new WebSocket(url);
} else if ("MozWebSocket" in window) {
ws = new WebSocket(url);
} else {
document.getElementById("messages").innerHTML += "This Browser does not support WebSockets<br />";
return;
}
ws.onopen = function(e) {
document.getElementById("messages").innerHTML += "Client: A connection to "+ws.URL+" has been opened.<br />";
document.getElementById("server_url").disabled = true;
document.getElementById("toggle_connect").innerHTML = "Disconnect";
};
ws.onerror = function(e) {
document.getElementById("messages").innerHTML += "Client: An error occured, see console log for more details.<br />";
console.log(e);
};
ws.onclose = function(e) {
document.getElementById("messages").innerHTML += "Client: The connection to "+url+" was closed.<br />";
};
ws.onmessage = function(e) {
document.getElementById("messages").innerHTML += "Server: "+e.data+"<br />";
};
}
function disconnect() {
ws.close();
document.getElementById("server_url").disabled = false;
document.getElementById("toggle_connect").innerHTML = "Connect";
}
function toggle_connect() {
if (document.getElementById("server_url").disabled === false) {
connect();
} else {
disconnect();
}
}
function send() {
if (ws === undefined || ws.readyState != 1) {
document.getElementById("messages").innerHTML += "Client: Websocket is not avaliable for writing<br />";
return;
}
ws.send(document.getElementById("msg").value);
document.getElementById("msg").value = "";
}
</script>
<style>
body,html {
margin: 0px;
padding: 0px;
}
#controls {
float:right;
background-color: #999;
}
</style>
<div id="controls">
<div id="server">
<input type="text" name="server_url" id="server_url" value="ws://localhost:5000" />
<button id="toggle_connect" onclick="toggle_connect();">Connect</button>
</div>
<div id="message_input"><input type="text" name="msg" id="msg" value="Hello World!" />
<button onclick="send();">Send</button></div>
</div>
<div id="messages"></div>
</body>
</html>

BIN
examples/echo_server/echo_server Executable file

Binary file not shown.

View File

@@ -0,0 +1,43 @@
#include "echo.hpp"
#include "../../src/websocketpp.hpp"
#include <boost/asio.hpp>
#include <iostream>
using boost::asio::ip::tcp;
int main(int argc, char* argv[]) {
std::string host = "localhost:5000";
short port = 5000;
if (argc == 3) {
// TODO: input validation?
port = atoi(argv[2]);
std::stringstream temp;
temp << argv[1] << ":" << port;
host = temp.str();
}
websocketecho::echo_handler_ptr echo_handler(new websocketecho::echo_handler());
try {
boost::asio::io_service io_service;
tcp::endpoint endpoint(tcp::v6(), port);
websocketpp::server_ptr server(
new websocketpp::server(io_service,endpoint,host,echo_handler)
);
std::cout << "Starting echo server on " << host << std::endl;
io_service.run();
} catch (std::exception& e) {
std::cerr << "Exception: " << e.what() << std::endl;
}
return 0;
}

97
readme.txt Normal file
View File

@@ -0,0 +1,97 @@
How to use this library
WebSocket++ is a C++ websocket server library implimented using the Boost Asio networking stack. It is designed to provide a simple interface
Built on Asio's proactor asyncronious event loop.
Building a program using WebSocket++ has two parts
1. Impliment a connection handler.
This is done by subclassing websocketpp::connection_handler. Each websocket connection is attached to a connection handler. The handler impliments the following methods:
validate:
Called after the client handshake is recieved but before the connection is accepted. Allows cookie authentication, origin checking, subprotocol negotiation, etc
connect:
Called when the connection has been established and writes are allowed.
disconnect:
Called when the connection has been disconnected
message: text and binary variants
Called when a new websocket message is recieved.
The handler has access to the following websocket session api:
get_header
returns the value of an HTTP header sent by the client during the handshake
get_request
returns the resource requested by the client in the handshake
set_handler:
pass responsibility for this connection to another connection handler
set_http_error:
reject the connection with a specific HTTP error
add_header
adds an HTTP header to the server handshake
set_subprotocol
selects a subprotocol for the connection
send: text and binary varients
send a websocket message
ping:
send a ping
2. Start Asio's event loop with a TCP endpoint and your connection handler
There are two example programs in the examples directory that demonstrate this use pattern. One is a trivial stateless echo server, the other is a simple web based chat client. Both include example javascript clients. The echo server is suitable for use with automated testing suites such as the Autobahn test suite.
By default, a single connection handler object is used for all connections. If needs require, that default handler can either store per-connection state itself or create new handlers and pass off responsibility for the connection to them.
How to build this library
Build static library
make
Build and install in system include directories
make install
Avaliable flags:
- SHARED=1: build a shared instead of static library.
- DEBUG=1: build library with no optimizations, suitable for debugging. Debug library
is called libwebsocketpp_dbg
- CXX=*: uses * as the c++ compiler instead of system default
Build tested on
- Mac OS X 10.7 with apple gcc 4.2, macports gcc 4.6, apple llvm/clang
Outstanding issues
- Acknowledgement details
- Add license information to each file
- Add license.txt file
- Subprotocol negotiation interface
- check draft 13 issues
- make website
Unimplimented features
- SSL
- frame or streaming based api
- client features
- extension negotiation interface
Acknowledgements
- Boost Asio and other libraries
- base64 library
- sha1 library
- htonll discussion
- build/makefile from libjson
- Autobahn test suite
- testing by Keith Brisson

123
src/base64/base64.cpp Normal file
View File

@@ -0,0 +1,123 @@
/*
base64.cpp and base64.h
Copyright (C) 2004-2008 René Nyffenegger
This source code is provided 'as-is', without any express or implied
warranty. In no event will the author be held liable for any damages
arising from the use of this software.
Permission is granted to anyone to use this software for any purpose,
including commercial applications, and to alter it and redistribute it
freely, subject to the following restrictions:
1. The origin of this source code must not be misrepresented; you must not
claim that you wrote the original source code. If you use this source code
in a product, an acknowledgment in the product documentation would be
appreciated but is not required.
2. Altered source versions must be plainly marked as such, and must not be
misrepresented as being the original source code.
3. This notice may not be removed or altered from any source distribution.
René Nyffenegger rene.nyffenegger@adp-gmbh.ch
*/
#include "base64.h"
#include <iostream>
static const std::string base64_chars =
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789+/";
static inline bool is_base64(unsigned char c) {
return (isalnum(c) || (c == '+') || (c == '/'));
}
std::string base64_encode(unsigned char const* bytes_to_encode, unsigned int in_len) {
std::string ret;
int i = 0;
int j = 0;
unsigned char char_array_3[3];
unsigned char char_array_4[4];
while (in_len--) {
char_array_3[i++] = *(bytes_to_encode++);
if (i == 3) {
char_array_4[0] = (char_array_3[0] & 0xfc) >> 2;
char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4);
char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6);
char_array_4[3] = char_array_3[2] & 0x3f;
for(i = 0; (i <4) ; i++)
ret += base64_chars[char_array_4[i]];
i = 0;
}
}
if (i)
{
for(j = i; j < 3; j++)
char_array_3[j] = '\0';
char_array_4[0] = (char_array_3[0] & 0xfc) >> 2;
char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4);
char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6);
char_array_4[3] = char_array_3[2] & 0x3f;
for (j = 0; (j < i + 1); j++)
ret += base64_chars[char_array_4[j]];
while((i++ < 3))
ret += '=';
}
return ret;
}
std::string base64_decode(std::string const& encoded_string) {
int in_len = encoded_string.size();
int i = 0;
int j = 0;
int in_ = 0;
unsigned char char_array_4[4], char_array_3[3];
std::string ret;
while (in_len-- && ( encoded_string[in_] != '=') && is_base64(encoded_string[in_])) {
char_array_4[i++] = encoded_string[in_]; in_++;
if (i ==4) {
for (i = 0; i <4; i++)
char_array_4[i] = base64_chars.find(char_array_4[i]);
char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
for (i = 0; (i < 3); i++)
ret += char_array_3[i];
i = 0;
}
}
if (i) {
for (j = i; j <4; j++)
char_array_4[j] = 0;
for (j = 0; j <4; j++)
char_array_4[j] = base64_chars.find(char_array_4[j]);
char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
for (j = 0; (j < i - 1); j++) ret += char_array_3[j];
}
return ret;
}

4
src/base64/base64.h Normal file
View File

@@ -0,0 +1,4 @@
#include <string>
std::string base64_encode(unsigned char const* , unsigned int len);
std::string base64_decode(std::string const& s);

26
src/network_utilities.cpp Normal file
View File

@@ -0,0 +1,26 @@
#include "network_utilities.hpp"
uint64_t htonll(uint64_t src) {
static int typ = TYP_INIT;
unsigned char c;
union {
uint64_t ull;
unsigned char c[8];
} x;
if (typ == TYP_INIT) {
x.ull = 0x01;
typ = (x.c[7] == 0x01ULL) ? TYP_BIGE : TYP_SMLE;
}
if (typ == TYP_BIGE)
return src;
x.ull = src;
c = x.c[0]; x.c[0] = x.c[7]; x.c[7] = c;
c = x.c[1]; x.c[1] = x.c[6]; x.c[6] = c;
c = x.c[2]; x.c[2] = x.c[5]; x.c[5] = c;
c = x.c[3]; x.c[3] = x.c[4]; x.c[4] = c;
return x.ull;
}
uint64_t ntohll(uint64_t src) {
return htonll(src);
}

17
src/network_utilities.hpp Normal file
View File

@@ -0,0 +1,17 @@
#ifndef NETWORK_UTILITIES_HPP
#define NETWORK_UTILITIES_HPP
#include <stdint.h>
// http://www.viva64.com/en/k/0018/
// TODO: impliment stuff from here:
// http://stackoverflow.com/questions/809902/64-bit-ntohl-in-c
#define TYP_INIT 0
#define TYP_SMLE 1
#define TYP_BIGE 2
uint64_t htonll(uint64_t src);
uint64_t ntohll(uint64_t src);
#endif // NETWORK_UTILITIES_HPP

41
src/sha1/Makefile Executable file
View File

@@ -0,0 +1,41 @@
#
# Makefile
#
# Copyright (C) 1998, 2009
# Paul E. Jones <paulej@packetizer.com>
# All Rights Reserved.
#
#############################################################################
# $Id: Makefile 12 2009-06-22 19:34:25Z paulej $
#############################################################################
#
# Description:
# This is a makefile for UNIX to build the programs sha, shacmp, and
# shatest
#
#
CC = g++
CFLAGS = -c -O2 -Wall -D_FILE_OFFSET_BITS=64
LIBS =
OBJS = sha1.o
all: sha shacmp shatest
sha: sha.o $(OBJS)
$(CC) -o $@ sha.o $(OBJS) $(LIBS)
shacmp: shacmp.o $(OBJS)
$(CC) -o $@ shacmp.o $(OBJS) $(LIBS)
shatest: shatest.o $(OBJS)
$(CC) -o $@ shatest.o $(OBJS) $(LIBS)
%.o: %.cpp
$(CC) $(CFLAGS) -o $@ $<
clean:
$(RM) *.o sha shacmp shatest

48
src/sha1/Makefile.nt Executable file
View File

@@ -0,0 +1,48 @@
#
# Makefile.nt
#
# Copyright (C) 1998, 2009
# Paul E. Jones <paulej@packetizer.com>
# All Rights Reserved.
#
#############################################################################
# $Id: Makefile.nt 13 2009-06-22 20:20:32Z paulej $
#############################################################################
#
# Description:
# This is a makefile for Win32 to build the programs sha, shacmp, and
# shatest
#
# Portability Issues:
# Designed to work with Visual C++
#
#
.silent:
!include <win32.mak>
RM = del /q
LIBS = $(conlibs) setargv.obj
CFLAGS = -D _CRT_SECURE_NO_WARNINGS /EHsc /O2 /W3
OBJS = sha1.obj
all: sha.exe shacmp.exe shatest.exe
sha.exe: sha.obj $(OBJS)
$(link) $(conflags) -out:$@ sha.obj $(OBJS) $(LIBS)
shacmp.exe: shacmp.obj $(OBJS)
$(link) $(conflags) -out:$@ shacmp.obj $(OBJS) $(LIBS)
shatest.exe: shatest.obj $(OBJS)
$(link) $(conflags) -out:$@ shatest.obj $(OBJS) $(LIBS)
.cpp.obj:
$(cc) $(CFLAGS) $(cflags) $(cvars) $<
clean:
$(RM) *.obj sha.exe shacmp.exe shatest.exe

14
src/sha1/license.txt Executable file
View File

@@ -0,0 +1,14 @@
Copyright (C) 1998, 2009
Paul E. Jones <paulej@packetizer.com>
Freeware Public License (FPL)
This software is licensed as "freeware." Permission to distribute
this software in source and binary forms, including incorporation
into other products, is hereby granted without a fee. THIS SOFTWARE
IS PROVIDED 'AS IS' AND WITHOUT ANY EXPRESSED OR IMPLIED WARRANTIES,
INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY
AND FITNESS FOR A PARTICULAR PURPOSE. THE AUTHOR SHALL NOT BE HELD
LIABLE FOR ANY DAMAGES RESULTING FROM THE USE OF THIS SOFTWARE, EITHER
DIRECTLY OR INDIRECTLY, INCLUDING, BUT NOT LIMITED TO, LOSS OF DATA
OR DATA BEING RENDERED INACCURATE.

176
src/sha1/sha.cpp Executable file
View File

@@ -0,0 +1,176 @@
/*
* sha.cpp
*
* Copyright (C) 1998, 2009
* Paul E. Jones <paulej@packetizer.com>
* All Rights Reserved
*
*****************************************************************************
* $Id: sha.cpp 13 2009-06-22 20:20:32Z paulej $
*****************************************************************************
*
* Description:
* This utility will display the message digest (fingerprint) for
* the specified file(s).
*
* Portability Issues:
* None.
*/
#include <stdio.h>
#include <string.h>
#ifdef WIN32
#include <io.h>
#endif
#include <fcntl.h>
#include "sha1.h"
/*
* Function prototype
*/
void usage();
/*
* main
*
* Description:
* This is the entry point for the program
*
* Parameters:
* argc: [in]
* This is the count of arguments in the argv array
* argv: [in]
* This is an array of filenames for which to compute message digests
*
* Returns:
* Nothing.
*
* Comments:
*
*/
int main(int argc, char *argv[])
{
SHA1 sha; // SHA-1 class
FILE *fp; // File pointer for reading files
char c; // Character read from file
unsigned message_digest[5]; // Message digest from "sha"
int i; // Counter
bool reading_stdin; // Are we reading standard in?
bool read_stdin = false; // Have we read stdin?
/*
* Check the program arguments and print usage information if -?
* or --help is passed as the first argument.
*/
if (argc > 1 && (!strcmp(argv[1],"-?") || !strcmp(argv[1],"--help")))
{
usage();
return 1;
}
/*
* For each filename passed in on the command line, calculate the
* SHA-1 value and display it.
*/
for(i = 0; i < argc; i++)
{
/*
* We start the counter at 0 to guarantee entry into the for loop.
* So if 'i' is zero, we will increment it now. If there is no
* argv[1], we will use STDIN below.
*/
if (i == 0)
{
i++;
}
if (argc == 1 || !strcmp(argv[i],"-"))
{
#ifdef WIN32
_setmode(_fileno(stdin), _O_BINARY);
#endif
fp = stdin;
reading_stdin = true;
}
else
{
if (!(fp = fopen(argv[i],"rb")))
{
fprintf(stderr, "sha: unable to open file %s\n", argv[i]);
return 2;
}
reading_stdin = false;
}
/*
* We do not want to read STDIN multiple times
*/
if (reading_stdin)
{
if (read_stdin)
{
continue;
}
read_stdin = true;
}
/*
* Reset the SHA1 object and process input
*/
sha.Reset();
c = fgetc(fp);
while(!feof(fp))
{
sha.Input(c);
c = fgetc(fp);
}
if (!reading_stdin)
{
fclose(fp);
}
if (!sha.Result(message_digest))
{
fprintf(stderr,"sha: could not compute message digest for %s\n",
reading_stdin?"STDIN":argv[i]);
}
else
{
printf( "%08X %08X %08X %08X %08X - %s\n",
message_digest[0],
message_digest[1],
message_digest[2],
message_digest[3],
message_digest[4],
reading_stdin?"STDIN":argv[i]);
}
}
return 0;
}
/*
* usage
*
* Description:
* This function will display program usage information to the user.
*
* Parameters:
* None.
*
* Returns:
* Nothing.
*
* Comments:
*
*/
void usage()
{
printf("usage: sha <file> [<file> ...]\n");
printf("\tThis program will display the message digest (fingerprint)\n");
printf("\tfor files using the Secure Hashing Algorithm (SHA-1).\n");
}

589
src/sha1/sha1.cpp Executable file
View File

@@ -0,0 +1,589 @@
/*
* sha1.cpp
*
* Copyright (C) 1998, 2009
* Paul E. Jones <paulej@packetizer.com>
* All Rights Reserved.
*
*****************************************************************************
* $Id: sha1.cpp 12 2009-06-22 19:34:25Z paulej $
*****************************************************************************
*
* Description:
* This class implements the Secure Hashing Standard as defined
* in FIPS PUB 180-1 published April 17, 1995.
*
* The Secure Hashing Standard, which uses the Secure Hashing
* Algorithm (SHA), produces a 160-bit message digest for a
* given data stream. In theory, it is highly improbable that
* two messages will produce the same message digest. Therefore,
* this algorithm can serve as a means of providing a "fingerprint"
* for a message.
*
* Portability Issues:
* SHA-1 is defined in terms of 32-bit "words". This code was
* written with the expectation that the processor has at least
* a 32-bit machine word size. If the machine word size is larger,
* the code should still function properly. One caveat to that
* is that the input functions taking characters and character arrays
* assume that only 8 bits of information are stored in each character.
*
* Caveats:
* SHA-1 is designed to work with messages less than 2^64 bits long.
* Although SHA-1 allows a message digest to be generated for
* messages of any number of bits less than 2^64, this implementation
* only works with messages with a length that is a multiple of 8
* bits.
*
*/
#include "sha1.h"
/*
* SHA1
*
* Description:
* This is the constructor for the sha1 class.
*
* Parameters:
* None.
*
* Returns:
* Nothing.
*
* Comments:
*
*/
SHA1::SHA1()
{
Reset();
}
/*
* ~SHA1
*
* Description:
* This is the destructor for the sha1 class
*
* Parameters:
* None.
*
* Returns:
* Nothing.
*
* Comments:
*
*/
SHA1::~SHA1()
{
// The destructor does nothing
}
/*
* Reset
*
* Description:
* This function will initialize the sha1 class member variables
* in preparation for computing a new message digest.
*
* Parameters:
* None.
*
* Returns:
* Nothing.
*
* Comments:
*
*/
void SHA1::Reset()
{
Length_Low = 0;
Length_High = 0;
Message_Block_Index = 0;
H[0] = 0x67452301;
H[1] = 0xEFCDAB89;
H[2] = 0x98BADCFE;
H[3] = 0x10325476;
H[4] = 0xC3D2E1F0;
Computed = false;
Corrupted = false;
}
/*
* Result
*
* Description:
* This function will return the 160-bit message digest into the
* array provided.
*
* Parameters:
* message_digest_array: [out]
* This is an array of five unsigned integers which will be filled
* with the message digest that has been computed.
*
* Returns:
* True if successful, false if it failed.
*
* Comments:
*
*/
bool SHA1::Result(unsigned *message_digest_array)
{
int i; // Counter
if (Corrupted)
{
return false;
}
if (!Computed)
{
PadMessage();
Computed = true;
}
for(i = 0; i < 5; i++)
{
message_digest_array[i] = H[i];
}
return true;
}
/*
* Input
*
* Description:
* This function accepts an array of octets as the next portion of
* the message.
*
* Parameters:
* message_array: [in]
* An array of characters representing the next portion of the
* message.
*
* Returns:
* Nothing.
*
* Comments:
*
*/
void SHA1::Input( const unsigned char *message_array,
unsigned length)
{
if (!length)
{
return;
}
if (Computed || Corrupted)
{
Corrupted = true;
return;
}
while(length-- && !Corrupted)
{
Message_Block[Message_Block_Index++] = (*message_array & 0xFF);
Length_Low += 8;
Length_Low &= 0xFFFFFFFF; // Force it to 32 bits
if (Length_Low == 0)
{
Length_High++;
Length_High &= 0xFFFFFFFF; // Force it to 32 bits
if (Length_High == 0)
{
Corrupted = true; // Message is too long
}
}
if (Message_Block_Index == 64)
{
ProcessMessageBlock();
}
message_array++;
}
}
/*
* Input
*
* Description:
* This function accepts an array of octets as the next portion of
* the message.
*
* Parameters:
* message_array: [in]
* An array of characters representing the next portion of the
* message.
* length: [in]
* The length of the message_array
*
* Returns:
* Nothing.
*
* Comments:
*
*/
void SHA1::Input( const char *message_array,
unsigned length)
{
Input((unsigned char *) message_array, length);
}
/*
* Input
*
* Description:
* This function accepts a single octets as the next message element.
*
* Parameters:
* message_element: [in]
* The next octet in the message.
*
* Returns:
* Nothing.
*
* Comments:
*
*/
void SHA1::Input(unsigned char message_element)
{
Input(&message_element, 1);
}
/*
* Input
*
* Description:
* This function accepts a single octet as the next message element.
*
* Parameters:
* message_element: [in]
* The next octet in the message.
*
* Returns:
* Nothing.
*
* Comments:
*
*/
void SHA1::Input(char message_element)
{
Input((unsigned char *) &message_element, 1);
}
/*
* operator<<
*
* Description:
* This operator makes it convenient to provide character strings to
* the SHA1 object for processing.
*
* Parameters:
* message_array: [in]
* The character array to take as input.
*
* Returns:
* A reference to the SHA1 object.
*
* Comments:
* Each character is assumed to hold 8 bits of information.
*
*/
SHA1& SHA1::operator<<(const char *message_array)
{
const char *p = message_array;
while(*p)
{
Input(*p);
p++;
}
return *this;
}
/*
* operator<<
*
* Description:
* This operator makes it convenient to provide character strings to
* the SHA1 object for processing.
*
* Parameters:
* message_array: [in]
* The character array to take as input.
*
* Returns:
* A reference to the SHA1 object.
*
* Comments:
* Each character is assumed to hold 8 bits of information.
*
*/
SHA1& SHA1::operator<<(const unsigned char *message_array)
{
const unsigned char *p = message_array;
while(*p)
{
Input(*p);
p++;
}
return *this;
}
/*
* operator<<
*
* Description:
* This function provides the next octet in the message.
*
* Parameters:
* message_element: [in]
* The next octet in the message
*
* Returns:
* A reference to the SHA1 object.
*
* Comments:
* The character is assumed to hold 8 bits of information.
*
*/
SHA1& SHA1::operator<<(const char message_element)
{
Input((unsigned char *) &message_element, 1);
return *this;
}
/*
* operator<<
*
* Description:
* This function provides the next octet in the message.
*
* Parameters:
* message_element: [in]
* The next octet in the message
*
* Returns:
* A reference to the SHA1 object.
*
* Comments:
* The character is assumed to hold 8 bits of information.
*
*/
SHA1& SHA1::operator<<(const unsigned char message_element)
{
Input(&message_element, 1);
return *this;
}
/*
* ProcessMessageBlock
*
* Description:
* This function will process the next 512 bits of the message
* stored in the Message_Block array.
*
* Parameters:
* None.
*
* Returns:
* Nothing.
*
* Comments:
* Many of the variable names in this function, especially the single
* character names, were used because those were the names used
* in the publication.
*
*/
void SHA1::ProcessMessageBlock()
{
const unsigned K[] = { // Constants defined for SHA-1
0x5A827999,
0x6ED9EBA1,
0x8F1BBCDC,
0xCA62C1D6
};
int t; // Loop counter
unsigned temp; // Temporary word value
unsigned W[80]; // Word sequence
unsigned A, B, C, D, E; // Word buffers
/*
* Initialize the first 16 words in the array W
*/
for(t = 0; t < 16; t++)
{
W[t] = ((unsigned) Message_Block[t * 4]) << 24;
W[t] |= ((unsigned) Message_Block[t * 4 + 1]) << 16;
W[t] |= ((unsigned) Message_Block[t * 4 + 2]) << 8;
W[t] |= ((unsigned) Message_Block[t * 4 + 3]);
}
for(t = 16; t < 80; t++)
{
W[t] = CircularShift(1,W[t-3] ^ W[t-8] ^ W[t-14] ^ W[t-16]);
}
A = H[0];
B = H[1];
C = H[2];
D = H[3];
E = H[4];
for(t = 0; t < 20; t++)
{
temp = CircularShift(5,A) + ((B & C) | ((~B) & D)) + E + W[t] + K[0];
temp &= 0xFFFFFFFF;
E = D;
D = C;
C = CircularShift(30,B);
B = A;
A = temp;
}
for(t = 20; t < 40; t++)
{
temp = CircularShift(5,A) + (B ^ C ^ D) + E + W[t] + K[1];
temp &= 0xFFFFFFFF;
E = D;
D = C;
C = CircularShift(30,B);
B = A;
A = temp;
}
for(t = 40; t < 60; t++)
{
temp = CircularShift(5,A) +
((B & C) | (B & D) | (C & D)) + E + W[t] + K[2];
temp &= 0xFFFFFFFF;
E = D;
D = C;
C = CircularShift(30,B);
B = A;
A = temp;
}
for(t = 60; t < 80; t++)
{
temp = CircularShift(5,A) + (B ^ C ^ D) + E + W[t] + K[3];
temp &= 0xFFFFFFFF;
E = D;
D = C;
C = CircularShift(30,B);
B = A;
A = temp;
}
H[0] = (H[0] + A) & 0xFFFFFFFF;
H[1] = (H[1] + B) & 0xFFFFFFFF;
H[2] = (H[2] + C) & 0xFFFFFFFF;
H[3] = (H[3] + D) & 0xFFFFFFFF;
H[4] = (H[4] + E) & 0xFFFFFFFF;
Message_Block_Index = 0;
}
/*
* PadMessage
*
* Description:
* According to the standard, the message must be padded to an even
* 512 bits. The first padding bit must be a '1'. The last 64 bits
* represent the length of the original message. All bits in between
* should be 0. This function will pad the message according to those
* rules by filling the message_block array accordingly. It will also
* call ProcessMessageBlock() appropriately. When it returns, it
* can be assumed that the message digest has been computed.
*
* Parameters:
* None.
*
* Returns:
* Nothing.
*
* Comments:
*
*/
void SHA1::PadMessage()
{
/*
* Check to see if the current message block is too small to hold
* the initial padding bits and length. If so, we will pad the
* block, process it, and then continue padding into a second block.
*/
if (Message_Block_Index > 55)
{
Message_Block[Message_Block_Index++] = 0x80;
while(Message_Block_Index < 64)
{
Message_Block[Message_Block_Index++] = 0;
}
ProcessMessageBlock();
while(Message_Block_Index < 56)
{
Message_Block[Message_Block_Index++] = 0;
}
}
else
{
Message_Block[Message_Block_Index++] = 0x80;
while(Message_Block_Index < 56)
{
Message_Block[Message_Block_Index++] = 0;
}
}
/*
* Store the message length as the last 8 octets
*/
Message_Block[56] = (Length_High >> 24) & 0xFF;
Message_Block[57] = (Length_High >> 16) & 0xFF;
Message_Block[58] = (Length_High >> 8) & 0xFF;
Message_Block[59] = (Length_High) & 0xFF;
Message_Block[60] = (Length_Low >> 24) & 0xFF;
Message_Block[61] = (Length_Low >> 16) & 0xFF;
Message_Block[62] = (Length_Low >> 8) & 0xFF;
Message_Block[63] = (Length_Low) & 0xFF;
ProcessMessageBlock();
}
/*
* CircularShift
*
* Description:
* This member function will perform a circular shifting operation.
*
* Parameters:
* bits: [in]
* The number of bits to shift (1-31)
* word: [in]
* The value to shift (assumes a 32-bit integer)
*
* Returns:
* The shifted value.
*
* Comments:
*
*/
unsigned SHA1::CircularShift(int bits, unsigned word)
{
return ((word << bits) & 0xFFFFFFFF) | ((word & 0xFFFFFFFF) >> (32-bits));
}

89
src/sha1/sha1.h Executable file
View File

@@ -0,0 +1,89 @@
/*
* sha1.h
*
* Copyright (C) 1998, 2009
* Paul E. Jones <paulej@packetizer.com>
* All Rights Reserved.
*
*****************************************************************************
* $Id: sha1.h 12 2009-06-22 19:34:25Z paulej $
*****************************************************************************
*
* Description:
* This class implements the Secure Hashing Standard as defined
* in FIPS PUB 180-1 published April 17, 1995.
*
* Many of the variable names in this class, especially the single
* character names, were used because those were the names used
* in the publication.
*
* Please read the file sha1.cpp for more information.
*
*/
#ifndef _SHA1_H_
#define _SHA1_H_
class SHA1
{
public:
SHA1();
virtual ~SHA1();
/*
* Re-initialize the class
*/
void Reset();
/*
* Returns the message digest
*/
bool Result(unsigned *message_digest_array);
/*
* Provide input to SHA1
*/
void Input( const unsigned char *message_array,
unsigned length);
void Input( const char *message_array,
unsigned length);
void Input(unsigned char message_element);
void Input(char message_element);
SHA1& operator<<(const char *message_array);
SHA1& operator<<(const unsigned char *message_array);
SHA1& operator<<(const char message_element);
SHA1& operator<<(const unsigned char message_element);
private:
/*
* Process the next 512 bits of the message
*/
void ProcessMessageBlock();
/*
* Pads the current message block to 512 bits
*/
void PadMessage();
/*
* Performs a circular left shift operation
*/
inline unsigned CircularShift(int bits, unsigned word);
unsigned H[5]; // Message digest buffers
unsigned Length_Low; // Message length in bits
unsigned Length_High; // Message length in bits
unsigned char Message_Block[64]; // 512-bit message blocks
int Message_Block_Index; // Index into message block array
bool Computed; // Is the digest computed?
bool Corrupted; // Is the message digest corruped?
};
#endif

169
src/sha1/shacmp.cpp Executable file
View File

@@ -0,0 +1,169 @@
/*
* shacmp.cpp
*
* Copyright (C) 1998, 2009
* Paul E. Jones <paulej@packetizer.com>
* All Rights Reserved
*
*****************************************************************************
* $Id: shacmp.cpp 12 2009-06-22 19:34:25Z paulej $
*****************************************************************************
*
* Description:
* This utility will compare two files by producing a message digest
* for each file using the Secure Hashing Algorithm and comparing
* the message digests. This function will return 0 if they
* compare or 1 if they do not or if there is an error.
* Errors result in a return code higher than 1.
*
* Portability Issues:
* none.
*
*/
#include <stdio.h>
#include <string.h>
#include "sha1.h"
/*
* Return codes
*/
#define SHA1_COMPARE 0
#define SHA1_NO_COMPARE 1
#define SHA1_USAGE_ERROR 2
#define SHA1_FILE_ERROR 3
/*
* Function prototype
*/
void usage();
/*
* main
*
* Description:
* This is the entry point for the program
*
* Parameters:
* argc: [in]
* This is the count of arguments in the argv array
* argv: [in]
* This is an array of filenames for which to compute message digests
*
* Returns:
* Nothing.
*
* Comments:
*
*/
int main(int argc, char *argv[])
{
SHA1 sha; // SHA-1 class
FILE *fp; // File pointer for reading files
char c; // Character read from file
unsigned message_digest[2][5]; // Message digest for files
int i; // Counter
bool message_match; // Message digest match flag
int returncode;
/*
* If we have two arguments, we will assume they are filenames. If
* we do not have to arguments, call usage() and exit.
*/
if (argc != 3)
{
usage();
return SHA1_USAGE_ERROR;
}
/*
* Get the message digests for each file
*/
for(i = 1; i <= 2; i++)
{
sha.Reset();
if (!(fp = fopen(argv[i],"rb")))
{
fprintf(stderr, "sha: unable to open file %s\n", argv[i]);
return SHA1_FILE_ERROR;
}
c = fgetc(fp);
while(!feof(fp))
{
sha.Input(c);
c = fgetc(fp);
}
fclose(fp);
if (!sha.Result(message_digest[i-1]))
{
fprintf(stderr,"shacmp: could not compute message digest for %s\n",
argv[i]);
return SHA1_FILE_ERROR;
}
}
/*
* Compare the message digest values
*/
message_match = true;
for(i = 0; i < 5; i++)
{
if (message_digest[0][i] != message_digest[1][i])
{
message_match = false;
break;
}
}
if (message_match)
{
printf("Fingerprints match:\n");
returncode = SHA1_COMPARE;
}
else
{
printf("Fingerprints do not match:\n");
returncode = SHA1_NO_COMPARE;
}
printf( "\t%08X %08X %08X %08X %08X\n",
message_digest[0][0],
message_digest[0][1],
message_digest[0][2],
message_digest[0][3],
message_digest[0][4]);
printf( "\t%08X %08X %08X %08X %08X\n",
message_digest[1][0],
message_digest[1][1],
message_digest[1][2],
message_digest[1][3],
message_digest[1][4]);
return returncode;
}
/*
* usage
*
* Description:
* This function will display program usage information to the user.
*
* Parameters:
* None.
*
* Returns:
* Nothing.
*
* Comments:
*
*/
void usage()
{
printf("usage: shacmp <file> <file>\n");
printf("\tThis program will compare the message digests (fingerprints)\n");
printf("\tfor two files using the Secure Hashing Algorithm (SHA-1).\n");
}

149
src/sha1/shatest.cpp Executable file
View File

@@ -0,0 +1,149 @@
/*
* shatest.cpp
*
* Copyright (C) 1998, 2009
* Paul E. Jones <paulej@packetizer.com>
* All Rights Reserved
*
*****************************************************************************
* $Id: shatest.cpp 12 2009-06-22 19:34:25Z paulej $
*****************************************************************************
*
* Description:
* This file will exercise the SHA1 class and perform the three
* tests documented in FIPS PUB 180-1.
*
* Portability Issues:
* None.
*
*/
#include <iostream>
#include "sha1.h"
using namespace std;
/*
* Define patterns for testing
*/
#define TESTA "abc"
#define TESTB "abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq"
/*
* Function prototype
*/
void DisplayMessageDigest(unsigned *message_digest);
/*
* main
*
* Description:
* This is the entry point for the program
*
* Parameters:
* None.
*
* Returns:
* Nothing.
*
* Comments:
*
*/
int main()
{
SHA1 sha;
unsigned message_digest[5];
/*
* Perform test A
*/
cout << endl << "Test A: 'abc'" << endl;
sha.Reset();
sha << TESTA;
if (!sha.Result(message_digest))
{
cerr << "ERROR-- could not compute message digest" << endl;
}
else
{
DisplayMessageDigest(message_digest);
cout << "Should match:" << endl;
cout << '\t' << "A9993E36 4706816A BA3E2571 7850C26C 9CD0D89D" << endl;
}
/*
* Perform test B
*/
cout << endl << "Test B: " << TESTB << endl;
sha.Reset();
sha << TESTB;
if (!sha.Result(message_digest))
{
cerr << "ERROR-- could not compute message digest" << endl;
}
else
{
DisplayMessageDigest(message_digest);
cout << "Should match:" << endl;
cout << '\t' << "84983E44 1C3BD26E BAAE4AA1 F95129E5 E54670F1" << endl;
}
/*
* Perform test C
*/
cout << endl << "Test C: One million 'a' characters" << endl;
sha.Reset();
for(int i = 1; i <= 1000000; i++) sha.Input('a');
if (!sha.Result(message_digest))
{
cerr << "ERROR-- could not compute message digest" << endl;
}
else
{
DisplayMessageDigest(message_digest);
cout << "Should match:" << endl;
cout << '\t' << "34AA973C D4C4DAA4 F61EEB2B DBAD2731 6534016F" << endl;
}
return 0;
}
/*
* DisplayMessageDigest
*
* Description:
* Display Message Digest array
*
* Parameters:
* None.
*
* Returns:
* Nothing.
*
* Comments:
*
*/
void DisplayMessageDigest(unsigned *message_digest)
{
ios::fmtflags flags;
cout << '\t';
flags = cout.setf(ios::hex|ios::uppercase,ios::basefield);
cout.setf(ios::uppercase);
for(int i = 0; i < 5 ; i++)
{
cout << message_digest[i] << ' ';
}
cout << endl;
cout.setf(flags);
}

View File

@@ -0,0 +1,61 @@
#ifndef WEBSOCKET_CONNECTION_HANDLER_HPP
#define WEBSOCKET_CONNECTION_HANDLER_HPP
#include <boost/shared_ptr.hpp>
#include <string>
#include <map>
namespace websocketpp {
class connection_handler;
typedef boost::shared_ptr<connection_handler> connection_handler_ptr;
}
#include "websocket_session.hpp"
namespace websocketpp {
class connection_handler {
public:
// validate will be called after a websocket handshake has been received and
// before it is accepted. It provides a handler the ability to refuse a
// connection based on application specific logic (ex: restrict domains or
// negotiate subprotocols)
//
// resource is the resource requested by the connection
// headers is a map containing the HTTP headers included in the client
// handshake as key/value strings
//
// if validate returns true the connection is allowed, otherwise the
// connection is closed.
virtual bool validate(session_ptr client) = 0;
// this will be called once the connected websocket is avaliable for
// writing messages. client may be a new websocket session or an existing
// session that was recently passed to this handler.
virtual void connect(session_ptr client) = 0;
// this will be called when the connected websocket is no longer avaliable
// for writing messages. This occurs under the following conditions:
// - Disconnect message recieved from the remote endpoint
// - Someone (usually this object) calls the disconnect method of session
// - A disconnect acknowledgement is recieved (in case another object
// calls the disconnect method of session
// - The connection handler assigned to this client was set to another
// handler
virtual void disconnect(session_ptr client,const std::string &reason) = 0;
// this will be called when a text message is recieved. Text will be
// encoded as UTF-8.
virtual void message(session_ptr client,const std::string &msg) = 0;
// this will be called when a binary message is recieved. Argument is a
// vector of the raw bytes in the message body.
virtual void message(session_ptr client,
const std::vector<unsigned char> &data) = 0;
};
}
#endif // WEBSOCKET_CONNECTION_HANDLER_HPP

340
src/websocket_frame.cpp Normal file
View File

@@ -0,0 +1,340 @@
#include "websocket_frame.hpp"
#include <iostream>
#include <algorithm>
#include <netinet/in.h>
using websocketpp::frame;
char* frame::get_header() {
return m_header;
}
char* frame::get_extended_header() {
return m_header+BASIC_HEADER_LENGTH;
}
unsigned int frame::get_header_len() const {
unsigned int temp = 2;
if (get_masked()) {
temp += 4;
}
if (get_basic_size() == 126) {
temp += 2;
} else if (get_basic_size() == 127) {
temp += 8;
}
return temp;
}
char* frame::get_masking_key() {
if (m_extended_header_bytes_needed > 0) {
throw "attempted to get masking_key before reading full header";
}
return m_masking_key;
}
// get and set header bits
bool frame::get_fin() const {
return ((m_header[0] & BPB0_FIN) == BPB0_FIN);
}
void frame::set_fin(bool fin) {
if (fin) {
m_header[0] |= BPB0_FIN;
} else {
m_header[0] &= (0xFF ^ BPB0_FIN);
}
}
// get and set reserved bits
bool frame::get_rsv1() const {
return ((m_header[0] & BPB0_RSV1) == BPB0_RSV1);
}
void frame::set_rsv1(bool b) {
if (b) {
m_header[0] |= BPB0_RSV1;
} else {
m_header[0] &= (0xFF ^ BPB0_RSV1);
}
}
bool frame::get_rsv2() const {
return ((m_header[0] & BPB0_RSV2) == BPB0_RSV2);
}
void frame::set_rsv2(bool b) {
if (b) {
m_header[0] |= BPB0_RSV2;
} else {
m_header[0] &= (0xFF ^ BPB0_RSV2);
}
}
bool frame::get_rsv3() const {
return ((m_header[0] & BPB0_RSV3) == BPB0_RSV3);
}
void frame::set_rsv3(bool b) {
if (b) {
m_header[0] |= BPB0_RSV3;
} else {
m_header[0] &= (0xFF ^ BPB0_RSV3);
}
}
frame::opcode frame::get_opcode() const {
return frame::opcode(m_header[0] & BPB0_OPCODE);
}
void frame::set_opcode(frame::opcode op) {
if (op > 0x0F) {
throw "invalid opcode";
}
if (get_basic_size() > BASIC_PAYLOAD_LIMIT &&
is_control()) {
throw "control frames can't have large payloads";
}
m_header[0] &= (0xFF ^ BPB0_OPCODE); // clear op bits
m_header[0] |= op; // set op bits
}
bool frame::get_masked() const {
return ((m_header[1] & BPB1_MASK) == BPB1_MASK);
}
void frame::set_masked(bool masked) {
if (masked) {
m_header[1] |= BPB1_MASK;
generate_masking_key();
} else {
m_header[1] &= (0xFF ^ BPB1_MASK);
clear_masking_key();
}
}
uint8_t frame::get_basic_size() const {
return m_header[1] & BPB1_PAYLOAD;
}
size_t frame::get_payload_size() const {
if (m_extended_header_bytes_needed > 0) {
// problem
throw "attempted to get payload size before reading full header";
}
return m_payload.size();
}
std::vector<unsigned char> &frame::get_payload() {
return m_payload;
}
void frame::set_payload(const std::vector<unsigned char> source) {
set_payload_helper(source.size());
std::copy(source.begin(),source.end(),m_payload.begin());
}
void frame::set_payload(const std::string source) {
set_payload_helper(source.size());
std::copy(source.begin(),source.end(),m_payload.begin());
}
bool frame::is_control() const {
return (get_opcode() > MAX_FRAME_OPCODE);
}
void frame::set_payload_helper(size_t s) {
if (s > max_payload_size) {
throw "requested payload is over implimentation defined limit";
}
// limits imposed by the websocket spec
if (s > BASIC_PAYLOAD_LIMIT &&
get_opcode() > MAX_FRAME_OPCODE) {
throw "control frames can't have large payloads";
}
if (s <= BASIC_PAYLOAD_LIMIT) {
m_header[1] = s;
} else if (s <= PAYLOAD_16BIT_LIMIT) {
m_header[1] = BASIC_PAYLOAD_16BIT_CODE;
// this reinterprets the second pair of bytes in m_header as a
// 16 bit int and writes the payload size there as an integer
// in network byte order
*reinterpret_cast<uint16_t*>(&m_header[BASIC_HEADER_LENGTH]) = htons(s);
} else if (s <= PAYLOAD_64BIT_LIMIT) {
m_header[1] = BASIC_PAYLOAD_64BIT_CODE;
*reinterpret_cast<uint64_t*>(&m_header[BASIC_HEADER_LENGTH]) = htonll(s);
} else {
throw "payload size limit is 63 bits";
}
m_payload.resize(s);
}
void frame::print_frame() const {
/*unsigned int len = get_header_len();
std::cout << "frame: ";
// print header
for (unsigned int i = 0; i < len; i++) {
std::cout << std::hex << (unsigned short)m_header[i] << " ";
}
// print message
for (auto &i: m_payload) {
std::cout << i;
}
std::cout << std::endl;*/
}
unsigned int frame::process_basic_header() {
m_extended_header_bytes_needed = 0;
m_payload.empty();
m_extended_header_bytes_needed = get_header_len() - BASIC_HEADER_LENGTH;
return m_extended_header_bytes_needed;
}
void frame::process_extended_header() {
m_extended_header_bytes_needed = 0;
uint8_t s = get_basic_size();
uint64_t payload_size;
int mask_index = BASIC_HEADER_LENGTH;
if (s <= BASIC_PAYLOAD_LIMIT) {
payload_size = s;
} else if (s == BASIC_PAYLOAD_16BIT_CODE) {
// reinterpret the second two bytes as a 16 bit integer in network
// byte order. Convert to host byte order and store locally.
payload_size = ntohs(*(
reinterpret_cast<uint16_t*>(&m_header[BASIC_HEADER_LENGTH])
));
mask_index += 2;
} else if (s == BASIC_PAYLOAD_64BIT_CODE) {
// reinterpret the second eight bytes as a 16 bit integer in
// network byte order. Convert to host byte order and store.
payload_size = ntohll(*(
reinterpret_cast<uint64_t*>(&m_header[BASIC_HEADER_LENGTH])
));
mask_index += 8;
} else {
// shouldn't be here
throw "invalid get_basic_size in process_extended_header";
}
if (payload_size < s) {
throw "payload size error";
}
if (get_masked() == 0) {
clear_masking_key();
} else {
// TODO: use this copy line (needs testing)
// std::copy(m_header[mask_index],m_header[mask_index+4],m_masking_key);
m_masking_key[0] = m_header[mask_index+0];
m_masking_key[1] = m_header[mask_index+1];
m_masking_key[2] = m_header[mask_index+2];
m_masking_key[3] = m_header[mask_index+3];
}
if (payload_size > max_payload_size) {
throw "got frame with payload greater than maximum frame buffer size.";
}
m_payload.resize(payload_size);
}
void frame::process_payload() {
// unmask payload one byte at a time
for (uint64_t i = 0; i < m_payload.size(); i++) {
m_payload[i] = (m_payload[i] ^ m_masking_key[i%4]);
}
}
void frame::process_payload2() {
// unmask payload one byte at a time
//uint64_t key = (*((uint32_t*)m_masking_key;)) << 32;
//key += *((uint32_t*)m_masking_key);
// might need to switch byte order
uint32_t key = *((uint32_t*)m_masking_key);
// 4
uint64_t i = 0;
uint64_t s = (m_payload.size() / 4);
std::cout << "s: " << s << std::endl;
// chunks of 4
for (i = 0; i < s; i+=4) {
((uint32_t*)(&m_payload[0]))[i] = (((uint32_t*)(&m_payload[0]))[i] ^ key);
}
// finish the last few
for (i = s; i < m_payload.size(); i++) {
m_payload[i] = (m_payload[i] ^ m_masking_key[i%4]);
}
}
bool frame::validate_basic_header() const {
// check for control frame size
if (get_basic_size() > BASIC_PAYLOAD_LIMIT && is_control()) {
return false;
}
// check for reserved opcodes
if (get_rsv1() || get_rsv2() || get_rsv3()) {
return false;
}
// check for reserved opcodes
opcode op = get_opcode();
if (op > 0x02 && op < 0x08) {
return false;
}
if (op > 0x0A) {
return false;
}
// check for fragmented control message
if (is_control() && !get_fin()) {
return false;
}
return true;
}
void frame::generate_masking_key() {
throw "masking key generation not implimented";
/* TODO: test and tune
boost::random::random_device rng;
boost::random::uniform_int_distribution<> mask(0,UINT32_MAX);
*(reinterpret_cast<uint32_t *>(m_masking_key)) = mask(rng);
*/
}
void frame::clear_masking_key() {
m_masking_key[0] = 0;
m_masking_key[1] = 0;
m_masking_key[2] = 0;
m_masking_key[3] = 0;
}

113
src/websocket_frame.hpp Normal file
View File

@@ -0,0 +1,113 @@
#ifndef WEBSOCKET_FRAME_HPP
#define WEBSOCKET_FRAME_HPP
#include "network_utilities.hpp"
#include <string>
#include <vector>
#include <cstring>
namespace websocketpp {
class frame {
public:
enum opcode_s {
CONTINUATION_FRAME = 0x00,
TEXT_FRAME = 0x01,
BINARY_FRAME = 0x02,
CONNECTION_CLOSE = 0x08,
PING = 0x09,
PONG = 0x0A
};
typedef enum opcode_s opcode;
static const uint8_t MAX_FRAME_OPCODE = 0x07;
// basic payload byte flags
static const uint8_t BPB0_OPCODE = 0x0F;
static const uint8_t BPB0_RSV3 = 0x10;
static const uint8_t BPB0_RSV2 = 0x20;
static const uint8_t BPB0_RSV1 = 0x40;
static const uint8_t BPB0_FIN = 0x80;
static const uint8_t BPB1_PAYLOAD = 0x7F;
static const uint8_t BPB1_MASK = 0x80;
static const uint8_t BASIC_PAYLOAD_LIMIT = 0x7D; // 125
static const uint8_t BASIC_PAYLOAD_16BIT_CODE = 0x7E; // 126
static const uint16_t PAYLOAD_16BIT_LIMIT = 0xFFFF; // 2^16, 65535
static const uint8_t BASIC_PAYLOAD_64BIT_CODE = 0x7F; // 127
static const uint64_t PAYLOAD_64BIT_LIMIT = 0x7FFFFFFFFFFFFFFF; // 2^63
static const unsigned int BASIC_HEADER_LENGTH = 2;
static const unsigned int MAX_HEADER_LENGTH = 14;
static const uint8_t extended_header_length = 12;
static const uint64_t max_payload_size = 100000000; // 100MB
// create an empty frame for writing into
frame() {
// not sure if these are necessary with c++ but putting in just in case
memset(m_header,0,MAX_HEADER_LENGTH);
}
// get pointers to underlying buffers
char* get_header();
char* get_extended_header();
unsigned int get_header_len() const;
char* get_masking_key();
// get and set header bits
bool get_fin() const;
void set_fin(bool fin);
bool get_rsv1() const;
void set_rsv1(bool b);
bool get_rsv2() const;
void set_rsv2(bool b);
bool get_rsv3() const;
void set_rsv3(bool b);
opcode get_opcode() const;
void set_opcode(opcode op);
bool get_masked() const;
void set_masked(bool masked);
uint8_t get_basic_size() const;
size_t get_payload_size() const;
std::vector<unsigned char> &get_payload();
void set_payload(const std::vector<unsigned char> source);
void set_payload(const std::string source);
void set_payload_helper(size_t s);
bool is_control() const;
void print_frame() const;
// reads basic header, sets and returns m_header_bits_needed
unsigned int process_basic_header();
void process_extended_header();
void process_payload();
void process_payload2(); // experiment with more efficient masking code.
bool validate_basic_header() const;
void generate_masking_key();
void clear_masking_key();
private:
char m_header[MAX_HEADER_LENGTH];
std::vector<unsigned char> m_payload;
char m_masking_key[4];
unsigned int m_extended_header_bytes_needed;
};
}
#endif // WEBSOCKET_FRAME_HPP

44
src/websocket_server.cpp Normal file
View File

@@ -0,0 +1,44 @@
#include "websocket_server.hpp"
#include <boost/bind.hpp>
#include <iostream>
using websocketpp::server;
server::server(boost::asio::io_service& io_service,
const tcp::endpoint& endpoint,
const std::string& host,
connection_handler_ptr defc)
: m_host(host),
m_io_service(io_service),
m_acceptor(io_service, endpoint),
m_def_con_handler(defc) {
this->start_accept();
}
void server::start_accept() {
session_ptr new_ws(new session(m_io_service,m_host,m_def_con_handler));
m_acceptor.async_accept(
new_ws->socket(),
boost::bind(
&server::handle_accept,
this,
new_ws,
boost::asio::placeholders::error
)
);
}
void server::handle_accept(session_ptr session,
const boost::system::error_code& error) {
if (!error) {
session->start();
} else {
std::cout << "Error" << std::endl;
}
this->start_accept();
}

40
src/websocket_server.hpp Normal file
View File

@@ -0,0 +1,40 @@
#ifndef WEBSOCKET_SERVER_HPP
#define WEBSOCKET_SERVER_HPP
#include "websocket_session.hpp"
#include <boost/asio.hpp>
#include <boost/shared_ptr.hpp>
using boost::asio::ip::tcp;
namespace websocketpp {
class server {
public:
server(boost::asio::io_service& io_service,
const tcp::endpoint& endpoint,
const std::string& host,
connection_handler_ptr defc);
// creates a new session object and connects the next websocket
// connection to it.
void start_accept();
// if no errors starts the session's read loop and returns to the
// start_accept phase.
void handle_accept(session_ptr session,
const boost::system::error_code& error);
private:
std::string m_host;
boost::asio::io_service& m_io_service;
tcp::acceptor m_acceptor;
connection_handler_ptr m_def_con_handler;
};
typedef boost::shared_ptr<server> server_ptr;
}
#endif // WEBSOCKET_SERVER_HPP

666
src/websocket_session.cpp Normal file
View File

@@ -0,0 +1,666 @@
#include "websocket_session.hpp"
#include "websocket_frame.hpp"
#include <boost/asio.hpp>
#include <boost/bind.hpp>
#include <boost/algorithm/string.hpp>
#include <string>
#include <iostream>
#include <sstream>
using websocketpp::session;
void session::start() {
//std::cout << "[Connection " << this << "] WebSocket Connection request from " << m_socket.remote_endpoint() << std::endl;
// async read to handle_read_handshake
boost::asio::async_read_until(
m_socket,
m_buf,
"\r\n\r\n",
boost::bind(
&session::handle_read_handshake,
shared_from_this(),
boost::asio::placeholders::error,
boost::asio::placeholders::bytes_transferred
)
);
}
void session::set_handler(connection_handler_ptr new_con) {
if (m_local_interface) {
m_local_interface->disconnect(shared_from_this(),"Setting new connection handler");
}
m_local_interface = new_con;
m_local_interface->connect(shared_from_this());
}
std::string session::get_header(const std::string& key) const {
std::map<std::string,std::string>::const_iterator h = m_headers.find(key);
if (h == m_headers.end()) {
return std::string();
} else {
return h->second;
}
}
std::string session::get_request() const {
return m_request;
}
void session::set_http_error(int code, std::string msg) {
m_http_error_code = code;
m_http_error_string = (msg != "" ? msg : lookup_http_error_string(code));
}
std::string session::lookup_http_error_string(int code) {
switch (code) {
case 400:
return "Bad Request";
case 401:
return "Unauthorized";
case 403:
return "Forbidden";
case 404:
return "Not Found";
case 405:
return "Method Not Allowed";
case 406:
return "Not Acceptable";
case 407:
return "Proxy Authentication Required";
case 408:
return "Request Timeout";
case 409:
return "Conflict";
case 410:
return "Gone";
case 411:
return "Length Required";
case 412:
return "Precondition Failed";
case 413:
return "Request Entity Too Large";
case 414:
return "Request-URI Too Long";
case 415:
return "Unsupported Media Type";
case 416:
return "Requested Range Not Satisfiable";
case 417:
return "Expectation Failed";
case 500:
return "Internal Server Error";
case 501:
return "Not Implimented";
case 502:
return "Bad Gateway";
case 503:
return "Service Unavailable";
case 504:
return "Gateway Timeout";
case 505:
return "HTTP Version Not Supported";
default:
return "Unknown";
}
}
void session::send(const std::string &msg) {
m_write_frame.set_fin(true);
m_write_frame.set_opcode(frame::TEXT_FRAME);
m_write_frame.set_payload(msg);
write_frame();
}
// send binary frame
void session::send(const std::vector<unsigned char> &data) {
m_write_frame.set_fin(true);
m_write_frame.set_opcode(frame::BINARY_FRAME);
m_write_frame.set_payload(data);
write_frame();
}
// send close frame
void session::disconnect(const std::string &reason) {
m_write_frame.set_fin(true);
m_write_frame.set_opcode(frame::CONNECTION_CLOSE);
m_write_frame.set_payload(reason);
write_frame();
if (m_local_interface) {
m_local_interface->disconnect(shared_from_this(),reason);
}
}
void session::ping(const std::string &msg) {
m_write_frame.set_fin(true);
m_write_frame.set_opcode(frame::PING);
m_write_frame.set_payload(msg);
write_frame();
}
void session::pong(const std::string &msg) {
m_write_frame.set_fin(true);
m_write_frame.set_opcode(frame::PONG);
m_write_frame.set_payload(msg);
write_frame();
}
void session::handle_read_handshake(const boost::system::error_code& e,
std::size_t bytes_transferred) {
// read handshake and set local state (or pass to write_handshake)
std::ostringstream line;
line << &m_buf;
m_handshake += line.str();
//std::cout << "=== Raw Message ===" << std::endl;
//std::cout << m_handshake << std::endl;
//std::cout << "=== Raw Message end ===" << std::endl;
std::vector<std::string> tokens;
std::string::size_type start = 0;
std::string::size_type end;
// Get request and parse headers
end = m_handshake.find("\r\n",start);
while(end != std::string::npos) {
tokens.push_back(m_handshake.substr(start, end - start));
start = end + 2;
end = m_handshake.find("\r\n",start);
}
for (size_t i = 0; i < tokens.size(); i++) {
if (i == 0) {
m_request = tokens[i];
}
end = tokens[i].find(": ",0);
if (end != std::string::npos) {
m_headers[tokens[i].substr(0,end)] = tokens[i].substr(end+2);
}
}
// error checking
// check the method
if (m_request.substr(0,4) != "GET ") {
// invalid method
std::cout << "Websocket handshake has invalid method: " << m_request.substr(0,4) << ", killing connection." << std::endl;
// TODO: exception
this->set_http_error(400);
this->write_http_error();
return;
}
// check the HTTP version
// TODO: allow versions greater than 1.1
end = m_request.find(" HTTP/1.1",4);
if (end == std::string::npos) {
std::cout << "Websocket handshake has invalid HTTP version, killing connection." << std::endl;
this->set_http_error(400); // or error 505 HTTP Version Not Supported
this->write_http_error();
return;
}
m_request = m_request.substr(4,end-4);
// TODO: use exceptions or a helper function or something better here
// verify the presence of required headers
if (get_header("Host") != m_host) {
std::cerr << "Invalid or missing Host header." << std::endl;
this->set_http_error(400);
}
if (!boost::iequals(get_header("Upgrade"),"websocket")) {
std::cerr << "Invalid or missing Upgrade header." << std::endl;
this->set_http_error(400);
}
if (get_header("Connection").find("Upgrade") == std::string::npos) {
// TODO: case insensitive?
std::cerr << "Invalid or missing Connection header." << std::endl;
this->set_http_error(400);
}
if (get_header("Sec-WebSocket-Key") == "") {
std::cerr << "Invalid or missing Sec-Websocket-Key header." << std::endl;
this->set_http_error(400);
}
if (get_header("Sec-WebSocket-Version") != "8" &&
get_header("Sec-WebSocket-Version") != "7") {
std::cerr << "Invalid or missing Sec-Websocket-Version header." << std::endl;
this->set_http_error(400);
}
if (m_http_error_code != 0) {
this->write_http_error();
return;
}
// optional headers (delegated to the local interface)
if (m_local_interface && !m_local_interface->validate(shared_from_this())) {
std::cerr << "Local interface rejected the connection." << std::endl;
if (m_http_error_code == 0) {
this->set_http_error(400);
}
}
if (m_http_error_code != 0) {
this->write_http_error();
return;
}
this->write_handshake();
}
void session::write_handshake() {
std::string server_handshake = "";
std::string server_key = m_headers["Sec-WebSocket-Key"];
server_key += "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
SHA1 sha;
uint32_t message_digest[5];
sha.Reset();
sha << server_key.c_str();
if (!sha.Result(message_digest)) {
std::cerr << "Error computing sha1 hash, killing connection." << std::endl;
return;
}
// convert sha1 hash bytes to network byte order because this sha1
// library works on ints rather than bytes
for (int i = 0; i < 5; i++) {
message_digest[i] = htonl(message_digest[i]);
}
server_key = base64_encode(
reinterpret_cast<const unsigned char*>(message_digest),20);
server_handshake += "HTTP/1.1 101 Switching Protocols\r\n";
server_handshake += "Upgrade: websocket\r\n";
server_handshake += "Connection: Upgrade\r\n";
server_handshake += "Sec-WebSocket-Accept: "+server_key+"\r\n\r\n";
// TODO: handler requested headers
//std::cout << server_handshake << std::endl;
// start async write to handle_write_handshake
boost::asio::async_write(
m_socket,
boost::asio::buffer(server_handshake),
boost::bind(
&session::handle_write_handshake,
shared_from_this(),
boost::asio::placeholders::error
)
);
}
void session::handle_write_handshake(const boost::system::error_code& error) {
if (error) {
handle_error("Error writing handshake",error);
return;
}
//std::cout << "WebSocket Version 8 connection opened." << std::endl;
m_status = OPEN;
if (m_local_interface) {
m_local_interface->connect(shared_from_this());
}
reset_message();
this->read_frame();
}
void session::write_http_error() {
std::stringstream server_handshake;
server_handshake << "HTTP/1.1 " << m_http_error_code << " "
<< m_http_error_string << "\r\n";
// additional headers?
server_handshake << "\r\n";
// start async write to handle_write_handshake
boost::asio::async_write(
m_socket,
boost::asio::buffer(server_handshake.str()),
boost::bind(
&session::handle_write_http_error,
shared_from_this(),
boost::asio::placeholders::error
)
);
}
void session::handle_write_http_error(const boost::system::error_code& error) {
if (error) {
handle_error("Error writing http response",error);
return;
}
}
void session::read_frame() {
boost::asio::async_read(
m_socket,
boost::asio::buffer(m_read_frame.get_header(),
frame::BASIC_HEADER_LENGTH),
boost::bind(
&session::handle_frame_header,
shared_from_this(),
boost::asio::placeholders::error
)
);
}
void session::handle_frame_header(const boost::system::error_code& error) {
if (error) {
handle_error("Error reading basic frame header",error);
return;
}
uint16_t extended_header_bytes = m_read_frame.process_basic_header();
if (!m_read_frame.validate_basic_header()) {
handle_error("Basic header validation failed",boost::system::error_code());
return;
}
if (extended_header_bytes == 0) {
read_payload();
} else {
boost::asio::async_read(
m_socket,
boost::asio::buffer(m_read_frame.get_extended_header(),
extended_header_bytes),
boost::bind(
&session::handle_extended_frame_header,
shared_from_this(),
boost::asio::placeholders::error
)
);
}
}
void session::handle_extended_frame_header(
const boost::system::error_code& error) {
if (error) {
handle_error("Error reading extended frame header",error);
return;
}
// this sets up the buffer we are about to read into.
m_read_frame.process_extended_header();
this->read_payload();
}
void session::read_payload() {
/*char * foo = m_read_frame.get_header();
std::cout << std::hex << ((uint16_t*)foo)[0] << std::endl;
std::cout << "opcode: " << m_read_frame.get_opcode() << std::endl;
std::cout << "fin: " << m_read_frame.get_fin() << std::endl;
std::cout << "mask: " << m_read_frame.get_masked() << std::endl;
std::cout << "size: " << (uint16_t)m_read_frame.get_basic_size() << std::endl;
std::cout << "payload_size: " << m_read_frame.get_payload_size() << std::endl;*/
boost::asio::async_read(
m_socket,
boost::asio::buffer(m_read_frame.get_payload()),
boost::bind(
&session::handle_read_payload,
shared_from_this(),
boost::asio::placeholders::error
)
);
}
void session::handle_read_payload (const boost::system::error_code& error) {
if (error) {
handle_error("Error reading payload data frame header",error);
return;
}
m_read_frame.process_payload();
switch (m_read_frame.get_opcode()) {
case frame::CONTINUATION_FRAME:
process_continuation();
break;
case frame::TEXT_FRAME:
process_text();
break;
case frame::BINARY_FRAME:
process_binary();
break;
case frame::CONNECTION_CLOSE:
process_close();
break;
case frame::PING:
process_ping();
break;
case frame::PONG:
process_pong();
break;
default:
// TODO: unknown frame type
// ignore or close?
break;
}
// check if there was an error processing this frame and fail the connection
if (m_error) {
return;
}
this->read_frame();
}
void session::handle_write_frame (const boost::system::error_code& error) {
if (error) {
handle_error("Error writing frame data",error);
}
//std::cout << "Successfully wrote frame." << std::endl;
}
void session::process_ping() {
std::cout << "Got ping" << std::endl;
// send pong
m_write_frame.set_fin(true);
m_write_frame.set_opcode(frame::PONG);
m_write_frame.set_payload(m_read_frame.get_payload());
write_frame();
}
void session::process_pong() {
std::cout << "Got pong" << std::endl;
}
void session::process_text() {
// text is binary.
process_binary();
}
void session::process_binary() {
if (m_fragmented) {
handle_error("Got a new message before the previous was finished.",
boost::system::error_code());
return;
}
m_current_opcode = m_read_frame.get_opcode();
if (m_read_frame.get_fin()) {
deliver_message();
reset_message();
} else {
m_fragmented = true;
extract_payload();
}
}
void session::process_continuation() {
if (!m_fragmented) {
handle_error("Got a continuation frame without an outstanding message.",
boost::system::error_code());
return;
}
extract_payload();
// check if we are done
if (m_read_frame.get_fin()) {
deliver_message();
reset_message();
}
}
void session::process_close() {
if (m_status == OPEN) {
// send response and set to closed
std::string msg(m_read_frame.get_payload().begin(),
m_read_frame.get_payload().end());
std::cout << "Got connection close message, acking and closing the connection. Reason was: " << msg << std::endl;
m_status = CLOSED;
// send acknowledgement
m_write_frame.set_fin(true);
m_write_frame.set_opcode(frame::CONNECTION_CLOSE);
m_write_frame.set_payload("");
write_frame();
// let our local interface know that the remote client has
// disconnected
if (m_local_interface) {
m_local_interface->disconnect(shared_from_this(),msg);
}
} else if (m_status == CLOSING) {
// this is an ack of my close message
// close cleanly
std::cout << "Got ack for my close message, closing the connection" << std::endl;
m_status = CLOSED;
// let our local interface know that the remote client has
// disconnected and the reason (if any)
if (m_local_interface) {
m_local_interface->disconnect(shared_from_this(),"");
}
} else {
// ignore
}
}
void session::deliver_message() {
if (!m_local_interface) {
return;
}
if (m_current_opcode == frame::BINARY_FRAME) {
if (m_fragmented) {
m_local_interface->message(shared_from_this(),m_current_message);
} else {
m_local_interface->message(shared_from_this(),
m_read_frame.get_payload());
}
} else if (m_current_opcode == frame::TEXT_FRAME) {
std::string msg;
if (m_fragmented) {
msg.append(m_current_message.begin(),m_current_message.end());
} else {
msg.append(
m_read_frame.get_payload().begin(),
m_read_frame.get_payload().end()
);
}
m_local_interface->message(shared_from_this(),msg);
} else {
// fail
}
}
void session::extract_payload() {
std::vector<unsigned char> &msg = m_read_frame.get_payload();
m_current_message.resize(m_current_message.size()+msg.size());
std::copy(msg.begin(),msg.end(),m_current_message.end()-msg.size());
}
void session::write_frame() {
// print debug info
m_write_frame.print_frame();
std::vector<boost::asio::mutable_buffer> data;
data.push_back(
boost::asio::buffer(
m_write_frame.get_header(),
m_write_frame.get_header_len()
)
);
data.push_back(
boost::asio::buffer(m_write_frame.get_payload())
);
boost::asio::async_write(
m_socket,
data,
boost::bind(
&session::handle_write_frame,
shared_from_this(),
boost::asio::placeholders::error
)
);
}
void session::reset_message() {
m_error = false;
m_fragmented = false;
m_current_message.clear();
}
void session::handle_error(std::string msg,
const boost::system::error_code& error) {
std::stringstream e;
e << msg << " (" << error << ")";
std::cerr << e.str() << std::endl;
if (m_local_interface) {
m_local_interface->disconnect(shared_from_this(),e.str());
}
m_error = true;
}

209
src/websocket_session.hpp Normal file
View File

@@ -0,0 +1,209 @@
#ifndef WEBSOCKET_SESSION_HPP
#define WEBSOCKET_SESSION_HPP
#include <iostream>
#include <string>
#include <sstream>
#include <vector>
#include <map>
#include <algorithm>
#include <boost/asio.hpp>
#include <boost/bind.hpp>
#include <boost/shared_ptr.hpp>
#include <boost/enable_shared_from_this.hpp>
#include <arpa/inet.h>
namespace websocketpp {
class session;
typedef boost::shared_ptr<session> session_ptr;
}
#include "websocket_frame.hpp"
#include "websocket_connection_handler.hpp"
#include "base64/base64.h"
#include "sha1/sha1.h"
using boost::asio::ip::tcp;
namespace websocketpp {
class session : public boost::enable_shared_from_this<session> {
public:
enum ws_status {
CONNECTING,
OPEN,
CLOSING,
CLOSED
};
typedef enum ws_status status_code;
session (boost::asio::io_service& io_service,
const std::string& host,
connection_handler_ptr defc)
: m_host(host),
m_socket(io_service),
m_status(CONNECTING),
m_http_error_code(0),
m_local_interface(defc) {}
tcp::socket& socket() {
return m_socket;
}
// This function is called to begin the session loop. This method and all
// that come after it are called as a result of an async event completing.
// if any method in this chain returns before adding a new async event the
// session will end.
void start();
// sets the internal connection handler of this connection to new_con.
// This is useful if you want to switch handler objects during a connection
// Example: a generic lobby handler could validate the handshake negotiate a
// sub protocol to talk to and then pass the connection off to a handler for
// that sub protocol.
void set_handler(connection_handler_ptr new_con);
// Public handshake interface
// By default a failed handshake validation will return an HTTP 400 Bad
// Request error. If your application wants to reject the connection for
// another reason it can be set here. Example: 404 if the resource request
// is not recognized or 403 forbidden if the origin does not match. If only
// a code is supplied a msg will be generated automatically.
void set_http_error(int code,std::string msg = "");
// gets the value of a header or the empty string if not present.
std::string get_header(const std::string &key) const;
// adds an arbitrary header to the server handshake HTTP response.
void add_header(const std::string &key,const std::string &value);
std::string get_request() const;
// sets the subprotocol being used. This will result in the appropriate
// Sec-WebSocket-Protocol header being sent back to the client. The value
// here must have been present in the client's opening handshake.
void set_subprotocol(const std::string &protocol);
//int get_version();
//void add_extension();
// Public session interface
// send basic frame types
void send(const std::string &msg); // text
void send(const std::vector<unsigned char> &data); // binary
void ping(const std::string &msg);
void pong(const std::string &msg);
void disconnect(const std::string &reason);
private:
// handle_read_handshake reads the HTTP headers of the initial websocket
// handshake, parses out the request and headers, and does error checking
// TODO: Generalize a lot of the hard coded things in this method.
void handle_read_handshake(const boost::system::error_code& e,
std::size_t bytes_transferred);
// write_handshake calculates the server portion of the handshake and
// sends it back.
// TODO: Generalize this to include things like protocols, cookies, etc
void write_handshake();
// handle_write_handshake checks for errors writing the server handshake,
// officially declares a connection open, notifies the local interface,
// and starts the frame reading loop.
void handle_write_handshake(const boost::system::error_code& error);
// construct and write an HTTP error in the case the handshake goes poorly
void write_http_error();
void handle_write_http_error(const boost::system::error_code& error);
// start async read for a websocket frame (2 bytes) to handle_frame_header
void read_frame();
// reads frame header and devices if it needs to read more header or go
// straight to the payload.
void handle_frame_header(const boost::system::error_code& error);
// process extra headers and start payload read
void handle_extended_frame_header(const boost::system::error_code& error);
// initiate payload read
void read_payload();
// now the frame object should be complete. Process and send it on then
// reset for new frame
void handle_read_payload (const boost::system::error_code& error);
// checks for errors writing frames
void handle_write_frame (const boost::system::error_code& error);
// helper functions for processing each opcode
void process_ping();
void process_pong();
void process_text();
void process_binary();
void process_continuation();
void process_close();
// deliver message if we have a local interface attached
void deliver_message();
// copies the current read frame payload into the session so that the read
// frame can be cleared for the next read. This is done when fragmented
// messages are recieved.
void extract_payload();
// write m_write_frame out to the socket.
void write_frame();
// reset session for a new message
void reset_message();
// prints a diagnostic message and disconnects the local interface
void handle_error(std::string msg,const boost::system::error_code& error);
std::string lookup_http_error_string(int code);
private:
std::string m_host;
tcp::socket m_socket;
status_code m_status;
int m_http_error_code;
std::string m_http_error_string;
boost::asio::streambuf m_buf;
std::string m_handshake;
std::string m_request;
std::map<std::string,std::string> m_headers;
frame m_read_frame;
frame m_write_frame;
std::vector<unsigned char> m_current_message;
bool m_error;
bool m_fragmented;
frame::opcode m_current_opcode;
connection_handler_ptr m_local_interface;
};
}
#endif // WEBSOCKET_SESSION_HPP
// better debug printing system
// set acceptible origin and host headers
// case sensitive header values? e.g. websocket
// double check bugs in autobahn (sending wrong localhost:9000 header) not
// checking masking in the 9.x tests

6
src/websocketpp.hpp Normal file
View File

@@ -0,0 +1,6 @@
#ifndef WEBSOCKETPP_HPP
#define WEBSOCKETPP_HPP
#include <websocketpp/websocket_server.hpp>
#endif // WEBSOCKETPP_HPP

11
todo.txt Normal file
View File

@@ -0,0 +1,11 @@
// TODO: impliment stuff from here:
// http://stackoverflow.com/questions/809902/64-bit-ntohl-in-c
boost unit test suite?
figure out exception handling setup
nagle?
extensions