Replace websocketd/websocat with hpws. (#131)

This commit is contained in:
Ravin Perera
2020-10-15 17:02:06 +05:30
committed by GitHub
parent 5f40aebf08
commit 7183383ab7
22 changed files with 1061 additions and 661 deletions

View File

@@ -1,74 +1,36 @@
#include "comm_server.hpp"
#include "comm_client.hpp"
#include "comm_session.hpp"
#include "comm_session_handler.hpp"
#include "../hplog.hpp"
#include "../util.hpp"
#include "../bill/corebill.h"
#include "../hpws/hpws.hpp"
namespace comm
{
constexpr uint32_t DEFAULT_MAX_MSG_SIZE = 16 * 1024 * 1024;
int comm_server::start(
const uint16_t port, const char *domain_socket_name, const SESSION_TYPE session_type,
const bool is_binary, const bool use_size_header, const bool require_tls,
const uint64_t (&metric_thresholds)[4], const std::set<conf::ip_port_pair> &req_known_remotes, const uint64_t max_msg_size)
const uint16_t port, const SESSION_TYPE session_type, const uint64_t (&metric_thresholds)[4],
const std::set<conf::ip_port_pair> &req_known_remotes, const uint64_t max_msg_size)
{
int accept_fd = open_domain_socket(domain_socket_name);
if (accept_fd > 0)
{
watchdog_thread = std::thread(
&comm_server::connection_watchdog, this, accept_fd, session_type, is_binary,
std::ref(metric_thresholds), req_known_remotes, max_msg_size);
const uint64_t final_max_msg_size = max_msg_size > 0 ? max_msg_size : DEFAULT_MAX_MSG_SIZE;
inbound_message_processor_thread = std::thread(&comm_server::inbound_message_processor_loop, this, session_type);
return start_websocketd_process(port, domain_socket_name, is_binary,
use_size_header, require_tls, max_msg_size);
}
return -1;
}
int comm_server::open_domain_socket(const char *domain_socket_name)
{
int fd = socket(AF_UNIX, SOCK_STREAM, 0);
if (fd == -1)
{
LOG_ERROR << errno << ": Domain socket open error";
if (start_hpws_server(port, final_max_msg_size) == -1)
return -1;
}
sockaddr_un addr;
memset(&addr, 0, sizeof(addr));
addr.sun_family = AF_UNIX;
watchdog_thread = std::thread(
&comm_server::connection_watchdog, this, session_type,
std::ref(metric_thresholds), req_known_remotes, final_max_msg_size);
strncpy(addr.sun_path, domain_socket_name, sizeof(addr.sun_path) - 1);
unlink(domain_socket_name);
inbound_message_processor_thread = std::thread(&comm_server::inbound_message_processor_loop, this, session_type);
if (bind(fd, (struct sockaddr *)&addr, sizeof(addr)) == -1)
{
LOG_ERROR << errno << ": Domain socket bind error";
return -1;
}
if (listen(fd, 5) == -1)
{
LOG_ERROR << errno << ": Domain socket listen error";
return -1;
}
// Set non-blocking behaviour.
// We do this so the accept() call returns immediately without blocking the listening thread.
int flags = fcntl(fd, F_GETFL);
fcntl(fd, F_SETFL, flags | O_NONBLOCK);
return fd; // This is the fd we should call accept() on.
return 0;
}
void comm_server::connection_watchdog(
const int accept_fd, const SESSION_TYPE session_type, const bool is_binary,
const uint64_t (&metric_thresholds)[4], const std::set<conf::ip_port_pair> &req_known_remotes, const uint64_t max_msg_size)
const SESSION_TYPE session_type, const uint64_t (&metric_thresholds)[4],
const std::set<conf::ip_port_pair> &req_known_remotes, const uint64_t max_msg_size)
{
util::mask_signal();
@@ -80,7 +42,7 @@ namespace comm
util::sleep(100);
// Accept any new incoming connection if available.
check_for_new_connection(sessions, accept_fd, session_type, is_binary, metric_thresholds, max_msg_size);
check_for_new_connection(sessions, session_type, metric_thresholds);
// Restore any missing outbound connections.
if (!req_known_remotes.empty())
@@ -88,100 +50,91 @@ namespace comm
if (loop_counter == 20)
{
loop_counter = 0;
maintain_known_connections(sessions, outbound_clients, req_known_remotes, session_type, is_binary, max_msg_size, metric_thresholds);
maintain_known_connections(sessions, req_known_remotes, session_type, max_msg_size, metric_thresholds);
}
loop_counter++;
}
// Cleanup any sessions that needs closure.
std::set<int> closed_session_fds;
for (auto &[fd, session] : sessions)
for (auto itr = sessions.begin(); itr != sessions.end();)
{
if (session.state == SESSION_STATE::MUST_CLOSE)
session.close(true);
if (itr->state == SESSION_STATE::MUST_CLOSE)
itr->close(true);
if (session.state == SESSION_STATE::CLOSED)
closed_session_fds.emplace(fd);
}
for (const int fd : closed_session_fds)
{
// Delete from sessions.
sessions.erase(fd);
// Delete from outbound clients.
const auto client_itr = outbound_clients.find(fd);
if (client_itr != outbound_clients.end())
{
client_itr->second.stop();
outbound_clients.erase(client_itr);
}
if (itr->state == SESSION_STATE::CLOSED)
itr = sessions.erase(itr);
else
++itr;
}
}
// If we reach this point that means we are shutting down.
// Close all sessions and clients
for (auto &[fd, session] : sessions)
// Close and erase all sessions.
for (comm_session &session : sessions)
session.close(false);
for (auto &[fd, client] : outbound_clients)
client.stop();
sessions.clear();
LOG_INFO << (session_type == SESSION_TYPE::USER ? "User" : "Peer") << " listener stopped.";
}
void comm_server::check_for_new_connection(
std::unordered_map<int, comm_session> &sessions, const int accept_fd,
const SESSION_TYPE session_type, const bool is_binary, const uint64_t (&metric_thresholds)[4],
const uint64_t max_msg_size)
std::list<comm_session> &sessions, const SESSION_TYPE session_type, const uint64_t (&metric_thresholds)[4])
{
// Accept new client connection (if available)
int client_fd = accept(accept_fd, NULL, NULL);
if (client_fd == -1 && errno != EAGAIN)
{
LOG_ERROR << errno << ": Domain socket accept error";
}
else if (client_fd > 0)
{
// New client connected.
const std::string ip = get_cgi_ip(client_fd);
if (!ip.empty())
{
if (corebill::is_banned(ip))
{
LOG_DEBUG << "Dropping connection for banned host " << ip;
close(client_fd);
}
else
{
comm_session session(ip, client_fd, client_fd, session_type, is_binary, true, metric_thresholds, max_msg_size);
if (session.on_connect() == 0)
{
std::scoped_lock<std::mutex> lock(sessions_mutex);
const auto [itr, success] = sessions.try_emplace(client_fd, std::move(session));
std::variant<hpws::client, hpws::error> accept_result = hpws_server.value().accept(true);
// Thread is seperately started after the moving operation to overcome the difficulty
// in accessing class member variables inside the thread.
// Class member variables gives unacceptable values if the thread starts before the move operation.
itr->second.start_messaging_threads();
}
}
if (std::holds_alternative<hpws::error>(accept_result))
{
const hpws::error error = std::get<hpws::error>(accept_result);
if (error.first == 199) // No client connected.
return;
LOG_ERROR << "Error in hpws accept():" << error.first << " " << error.second;
return;
}
// New client connected.
hpws::client client = std::move(std::get<hpws::client>(accept_result));
const std::variant<std::string, hpws::error> host_result = client.host_address();
if (std::holds_alternative<hpws::error>(host_result))
{
const hpws::error error = std::get<hpws::error>(host_result);
LOG_ERROR << "Error getting ip from hpws:" << error.first << " " << error.second;
}
else
{
const std::string &host_address = std::get<std::string>(host_result);
if (corebill::is_banned(host_address))
{
// We just let the client object gets destructed without adding it to a session.
LOG_DEBUG << "Dropping connection for banned host " << host_address;
}
else
{
close(client_fd);
LOG_ERROR << "Closed bad client socket: " << client_fd;
comm_session session(host_address, std::move(client), session_type, true, metric_thresholds);
if (session.on_connect() == 0)
{
std::scoped_lock<std::mutex> lock(sessions_mutex);
comm_session &inserted_session = sessions.emplace_back(std::move(session));
// Thread is seperately started after the moving operation to overcome the difficulty
// in accessing class member variables inside the thread.
// Class member variables gives unacceptable values if the thread starts before the move operation.
inserted_session.start_messaging_threads();
}
}
}
}
void comm_server::maintain_known_connections(
std::unordered_map<int, comm_session> &sessions, std::unordered_map<int, comm_client> &outbound_clients,
const std::set<conf::ip_port_pair> &req_known_remotes, const SESSION_TYPE session_type, const bool is_binary,
const uint64_t max_msg_size, const uint64_t (&metric_thresholds)[4])
std::list<comm_session> &sessions, const std::set<conf::ip_port_pair> &req_known_remotes,
const SESSION_TYPE session_type, const uint64_t max_msg_size, const uint64_t (&metric_thresholds)[4])
{
// Find already connected known remote parties list
std::set<conf::ip_port_pair> known_remotes;
for (const auto &[fd, session] : sessions)
for (const comm_session &session : sessions)
{
if (session.state != SESSION_STATE::CLOSED && !session.known_ipport.first.empty())
known_remotes.emplace(session.known_ipport);
@@ -200,28 +153,39 @@ namespace comm
const uint16_t port = ipport.second;
LOG_DEBUG << "Trying to connect " << host << ":" << std::to_string(port);
comm::comm_client client;
if (client.start(host, port, metric_thresholds, conf::cfg.peermaxsize) == -1)
std::variant<hpws::client, hpws::error> client_result = hpws::client::connect(conf::ctx.hpws_exe_path, max_msg_size, host, port, "/", {}, util::fork_detach);
if (std::holds_alternative<hpws::error>(client_result))
{
LOG_ERROR << "Outbound connection attempt failed: " << host << ":" << std::to_string(port);
const hpws::error error = std::get<hpws::error>(client_result);
LOG_ERROR << "Outbound connection hpws error:" << error.first << " " << error.second;
}
else
{
comm::comm_session session(host, client.read_fd, client.write_fd, comm::SESSION_TYPE::PEER, is_binary, false, metric_thresholds, max_msg_size);
session.known_ipport = ipport;
if (session.on_connect() == 0)
hpws::client client = std::move(std::get<hpws::client>(client_result));
const std::variant<std::string, hpws::error> host_result = client.host_address();
if (std::holds_alternative<hpws::error>(host_result))
{
const hpws::error error = std::get<hpws::error>(host_result);
LOG_ERROR << "Error getting ip from hpws:" << error.first << " " << error.second;
}
else
{
const std::string &host_address = std::get<std::string>(host_result);
comm::comm_session session(host_address, std::move(client), session_type, false, metric_thresholds);
session.known_ipport = ipport;
if (session.on_connect() == 0)
{
std::scoped_lock<std::mutex> lock(sessions_mutex);
const auto [itr, success] = sessions.try_emplace(client.read_fd, std::move(session));
comm_session &inserted_session = sessions.emplace_back(std::move(session));
// Thread is seperately started after the moving operation to overcome the difficulty
// in accessing class member variables inside the thread.
// Class member variables gives unacceptable values if the thread starts before the move operation.
itr->second.start_messaging_threads();
}
inserted_session.start_messaging_threads();
outbound_clients.emplace(client.read_fd, std::move(client));
known_remotes.emplace(ipport);
known_remotes.emplace(ipport);
}
}
}
}
@@ -238,7 +202,7 @@ namespace comm
{
// Process one message from each session in round-robin fashion.
std::scoped_lock<std::mutex> lock(sessions_mutex);
for (auto &[fd, session] : sessions)
for (comm_session &session : sessions)
{
const int result = session.process_next_inbound_message();
@@ -258,170 +222,38 @@ namespace comm
LOG_INFO << (session_type == SESSION_TYPE::USER ? "User" : "Peer") << " message processor stopped.";
}
int comm_server::start_websocketd_process(
const uint16_t port, const char *domain_socket_name, const bool is_binary,
const bool use_size_header, const bool require_tls, const uint64_t max_msg_size)
int comm_server::start_hpws_server(const uint16_t port, const uint64_t max_msg_size)
{
// setup pipe for firewall
int firewall_pipe[2]; // parent to child pipe
std::variant<hpws::server, hpws::error> result = hpws::server::create(
conf::ctx.hpws_exe_path,
max_msg_size,
port,
512, // Max connections
2, // Max connections per IP.
conf::ctx.tls_cert_file,
conf::ctx.tls_key_file,
{},
util::fork_detach);
if (pipe(firewall_pipe))
if (std::holds_alternative<hpws::error>(result))
{
LOG_ERROR << errno << ": pipe() call failed for firewall";
}
else
{
firewall_out = firewall_pipe[1];
}
const pid_t pid = fork();
if (pid > 0)
{
// HotPocket process.
// Close the child reading end of the pipe in the parent
if (firewall_out > 0)
close(firewall_pipe[0]);
// Wait for some time and check if websocketd is still running properly.
// Sending signal 0 to test whether process exist.
util::sleep(20);
if (util::kill_process(pid, false, 0) == -1)
return -1;
websocketd_pid = pid;
}
else if (pid == 0)
{
// Websocketd process.
util::fork_detach();
// We are using websocketd forked repo: https://github.com/codetsunami/websocketd
if (firewall_out > 0)
{
// Close parent writing end of the pipe in the child
close(firewall_pipe[1]);
// Override stdin in the child's file table
dup2(firewall_pipe[0], 0);
}
std::vector<std::string> args_vec;
args_vec.reserve(16);
const std::string max_msg_size_str = (max_msg_size > 0) ? std::to_string(max_msg_size) : "134217728"; //128MB
// Fill process args.
args_vec.push_back(conf::ctx.websocketd_exe_path);
if (require_tls)
{
args_vec.push_back("--ssl");
args_vec.push_back("--sslcert");
args_vec.push_back(conf::ctx.tls_cert_file);
args_vec.push_back("--sslkey");
args_vec.push_back(conf::ctx.tls_key_file);
}
args_vec.push_back("--port");
args_vec.push_back(std::to_string(port));
args_vec.push_back(is_binary ? "--binary=true" : "--binary=false");
args_vec.push_back(use_size_header ? "--sizeheader=true" : "--sizeheader=false");
args_vec.push_back(std::string("--maxframe=").append(max_msg_size_str));
args_vec.push_back("--loglevel=error");
args_vec.push_back("nc"); // netcat (OpenBSD) is used for domain socket redirection.
args_vec.push_back("-U"); // Use UNIX domain socket
args_vec.push_back(domain_socket_name);
char *execv_args[args_vec.size()];
int idx = 0;
for (std::string &arg : args_vec)
execv_args[idx++] = arg.data();
execv_args[idx] = NULL;
const int ret = execv(execv_args[0], execv_args);
std::cerr << errno << ": websocketd process execv failed.\n";
exit(1);
}
else
{
LOG_ERROR << errno << ": fork() failed when starting websocketd process.";
const hpws::error e = std::get<hpws::error>(result);
LOG_ERROR << "Error creating hpws server:" << e.first << " " << e.second;
return -1;
}
hpws_server.emplace(std::move(std::get<hpws::server>(result)));
return 0;
}
void comm_server::firewall_ban(std::string_view ip, const bool unban)
{
if (firewall_out < 0)
return;
iovec iov[]{
{(void *)(unban ? "r" : "a"), 1},
{(void *)ip.data(), ip.length()}};
writev(firewall_out, iov, 2);
}
/**
* If the fd supplied was produced by accept()ing unix domain socket connection
* the process at the other end is inspected for CGI environment variables
* and the REMOTE_ADDR variable is returned as std::string, otherwise empty string
*/
std::string comm_server::get_cgi_ip(const int fd)
{
socklen_t length;
ucred uc;
length = sizeof(struct ucred);
// Ask the operating system for information about the other process
if (getsockopt(fd, SOL_SOCKET, SO_PEERCRED, &uc, &length) == -1)
{
LOG_ERROR << errno << ": Could not retrieve PID from unix domain socket";
return "";
}
// Open /proc/<pid>/environ for that process
std::stringstream ss;
ss << "/proc/" << uc.pid << "/environ";
std::string fn = ss.str();
const int envfd = open(fn.c_str(), O_RDONLY | O_CLOEXEC);
if (!envfd)
{
LOG_ERROR << errno << ": Could not open environ block for process on other end of unix domain socket PID=" << uc.pid;
return "";
}
// Read environ block
char envblock[0x7fff];
const ssize_t bytes_read = read(envfd, envblock, 0x7fff); //0x7fff bytes is an operating system size limit for this block
close(envfd);
// Find the REMOTE_ADDR entry. Envrion block delimited by \0
for (char *upto = envblock, *last = envblock; upto - envblock < bytes_read; ++upto)
{
if (*upto == '\0')
{
if (upto - last > 12 && strncmp(last, "REMOTE_ADDR=", 12) == 0)
return std::string((const char *)(last + 12));
last = upto + 1;
}
}
LOG_ERROR << "Could not find REMOTE_ADDR variable in /proc/" << uc.pid << "/environ";
return "";
}
void comm_server::stop()
{
should_stop_listening = true;
watchdog_thread.join();
inbound_message_processor_thread.join();
hpws_server.reset();
if (websocketd_pid > 0)
util::kill_process(websocketd_pid, false); // Kill websocketd.
inbound_message_processor_thread.join();
}
} // namespace comm