#include "comm_server.hpp"
#include "comm_client.hpp"
#include "comm_session.hpp"
#include "comm_session_handler.hpp"
#include "../hplog.hpp"
#include "../util.hpp"
#include "../bill/corebill.h"

namespace comm
{
    /**
     * Starts the communication server.
     *
     * Opens the listening unix domain socket, spawns the connection watchdog thread and the inbound
     * message processor thread, then launches the websocketd process which forwards websocket
     * connections on `port` to the domain socket.
     *
     * @param port               TCP port passed to websocketd (--port).
     * @param domain_socket_name Filesystem path of the unix domain socket.
     * @param session_type       Session type assigned to connections accepted on the domain socket.
     * @param is_binary          Passed to websocketd (--binary) and to every session.
     * @param use_size_header    Passed to websocketd (--sizeheader).
     * @param metric_thresholds  Handed to the watchdog thread by reference (std::ref), so the caller
     *                           must keep the array alive until stop() has returned.
     * @param req_known_remotes  Remote parties to keep outbound connections to (copied into the thread).
     * @param max_msg_size       Maximum message size; 0 makes websocketd use a 134217728 byte frame limit.
     * @return 0 on success, -1 on failure.
     */
    int comm_server::start(
        const uint16_t port, const char *domain_socket_name, const SESSION_TYPE session_type, const bool is_binary, const bool use_size_header,
        const uint64_t (&metric_thresholds)[4], const std::set<conf::ip_port_pair> &req_known_remotes, const uint64_t max_msg_size)
    {
        const int accept_fd = open_domain_socket(domain_socket_name);
        if (accept_fd == -1)
            return -1;

        // std::thread decay-copies its arguments: the array would decay to a pointer and no longer bind to
        // the `const uint64_t (&)[4]` parameter, hence std::ref.
        watchdog_thread = std::thread(
            &comm_server::connection_watchdog, this, accept_fd, session_type, is_binary,
            std::ref(metric_thresholds), req_known_remotes, max_msg_size);

        inbound_message_processor_thread = std::thread(&comm_server::inbound_message_processor_loop, this, session_type);

        return start_websocketd_process(port, domain_socket_name, is_binary, use_size_header, max_msg_size);
    }

    /**
     * Creates a non-blocking unix domain stream socket bound to `domain_socket_name` and puts it into
     * listening mode. Any existing file at that path is unlinked first.
     * @return The listening fd on which accept() should be called, or -1 on failure (the socket is
     *         closed again on every failure path).
     */
    int comm_server::open_domain_socket(const char *domain_socket_name)
    {
        sockaddr_un addr;
        memset(&addr, 0, sizeof(addr));

        // sun_path is a fixed-size buffer. A silently truncated path would bind a different file than the
        // one websocketd is told to connect to (nc -U), so reject names that do not fit.
        if (strlen(domain_socket_name) >= sizeof(addr.sun_path))
        {
            LOG_ERROR << "Domain socket path too long: " << domain_socket_name;
            return -1;
        }

        const int fd = socket(AF_UNIX, SOCK_STREAM, 0);
        if (fd == -1)
        {
            LOG_ERROR << errno << ": Domain socket open error";
            return -1;
        }

        addr.sun_family = AF_UNIX;
        strncpy(addr.sun_path, domain_socket_name, sizeof(addr.sun_path) - 1);
        unlink(domain_socket_name);

        if (bind(fd, reinterpret_cast<sockaddr *>(&addr), sizeof(addr)) == -1)
        {
            LOG_ERROR << errno << ": Domain socket bind error";
            close(fd);
            return -1;
        }

        if (listen(fd, 5) == -1)
        {
            LOG_ERROR << errno << ": Domain socket listen error";
            close(fd);
            return -1;
        }

        // Non-blocking so that accept() returns immediately instead of stalling the watchdog loop
        // (a blocking accept would also keep stop() from ever joining the watchdog thread).
        const int flags = fcntl(fd, F_GETFL);
        if (flags == -1 || fcntl(fd, F_SETFL, flags | O_NONBLOCK) == -1)
        {
            LOG_ERROR << errno << ": Domain socket non-blocking mode error";
            close(fd);
            return -1;
        }

        return fd; // This is the fd we should call accept() on.
    }

    /**
     * Watchdog thread body. Each iteration sleeps via util::sleep(100), accepts a pending connection (if
     * any), re-establishes missing outbound connections to req_known_remotes roughly every 20 iterations,
     * and removes sessions that have reached the CLOSED state (closing MUST_CLOSE sessions first).
     * On shutdown it closes all sessions, stops all outbound clients and closes the listening socket.
     */
    void comm_server::connection_watchdog(
        const int accept_fd, const SESSION_TYPE session_type, const bool is_binary,
        const uint64_t (&metric_thresholds)[4], const std::set<conf::ip_port_pair> &req_known_remotes, const uint64_t max_msg_size)
    {
        util::mask_signal();

        // Counter to track when to initiate outbound client connections.
        int16_t loop_counter = -1;

        while (!should_stop_listening)
        {
            util::sleep(100);

            // Accept any new incoming connection if available.
            check_for_new_connection(sessions, accept_fd, session_type, is_binary, metric_thresholds, max_msg_size);

            // Restore any missing outbound connections.
            if (!req_known_remotes.empty())
            {
                if (loop_counter == 20)
                {
                    loop_counter = 0;
                    maintain_known_connections(sessions, outbound_clients, req_known_remotes, session_type, is_binary, max_msg_size, metric_thresholds);
                }
                loop_counter++;
            }

            // Cleanup any sessions that need closure.
            std::set<int> closed_session_fds;
            for (auto &[fd, session] : sessions)
            {
                if (session.state == SESSION_STATE::MUST_CLOSE)
                    session.close(true);

                if (session.state == SESSION_STATE::CLOSED)
                    closed_session_fds.emplace(fd);
            }

            for (const int fd : closed_session_fds)
            {
                {
                    // inbound_message_processor_loop iterates `sessions` while holding sessions_mutex, so the
                    // map must only be structurally modified under that same mutex.
                    std::scoped_lock lock(sessions_mutex);
                    sessions.erase(fd);
                }

                // Delete from outbound clients.
                const auto client_itr = outbound_clients.find(fd);
                if (client_itr != outbound_clients.end())
                {
                    client_itr->second.stop();
                    outbound_clients.erase(client_itr);
                }
            }
        }

        // If we reach this point that means we are shutting down.
        // Close all sessions and clients.
        for (auto &[fd, session] : sessions)
            session.close(false);

        for (auto &[fd, client] : outbound_clients)
            client.stop();

        // The listening socket is owned by this thread (it is not stored anywhere else).
        close(accept_fd);

        LOG_INFO << (session_type == SESSION_TYPE::USER ? "User" : "Peer") << " listener stopped.";
    }

    /**
     * Accepts at most one pending connection on the non-blocking listening socket.
     * The client's IP is taken from the REMOTE_ADDR environment variable of the process on the other end
     * of the domain socket (see get_cgi_ip). Connections without an IP, or from banned hosts, are closed.
     * Otherwise a session is created and, if on_connect() succeeds, registered in `sessions` under
     * sessions_mutex and its messaging threads are started.
     */
    void comm_server::check_for_new_connection(
        std::unordered_map<int, comm_session> &sessions, const int accept_fd, const SESSION_TYPE session_type,
        const bool is_binary, const uint64_t (&metric_thresholds)[4], const uint64_t max_msg_size)
    {
        // Accept new client connection (if available).
        const int client_fd = accept(accept_fd, NULL, NULL);
        if (client_fd == -1)
        {
            // EAGAIN/EWOULDBLOCK only means there was no pending connection on the non-blocking socket.
            if (errno != EAGAIN && errno != EWOULDBLOCK)
                LOG_ERROR << errno << ": Domain socket accept error";
            return;
        }

        // New client connected.
        const std::string ip = get_cgi_ip(client_fd);
        if (ip.empty())
        {
            close(client_fd);
            LOG_ERROR << "Closed bad client socket: " << client_fd;
            return;
        }

        if (corebill::is_banned(ip))
        {
            LOG_DEBUG << "Dropping connection for banned host " << ip;
            close(client_fd);
            return;
        }

        comm_session session(ip, client_fd, client_fd, session_type, is_binary, true, metric_thresholds, max_msg_size);
        // NOTE(review): when on_connect() fails client_fd is not closed here — confirm comm_session handles it.
        if (session.on_connect() != 0)
            return;

        std::scoped_lock lock(sessions_mutex);
        const auto [itr, inserted] = sessions.try_emplace(client_fd, std::move(session));
        if (!inserted)
        {
            // try_emplace returned the already-registered session; its threads are running already.
            LOG_ERROR << "Session already registered for client socket: " << client_fd;
            return;
        }

        // Threads are started separately, after the move, to overcome the difficulty of accessing class
        // member variables inside the thread: members hold unacceptable values if the thread starts
        // before the move operation.
        itr->second.start_messaging_threads();
    }

    /**
     * Opens outbound connections to every entry of req_known_remotes that has no live (non-CLOSED)
     * session yet. Successful connections are registered in `sessions` (under sessions_mutex) and their
     * clients in `outbound_clients`, both keyed by the client's read fd.
     */
    void comm_server::maintain_known_connections(
        std::unordered_map<int, comm_session> &sessions, std::unordered_map<int, comm_client> &outbound_clients,
        const std::set<conf::ip_port_pair> &req_known_remotes, const SESSION_TYPE session_type, const bool is_binary,
        const uint64_t max_msg_size, const uint64_t (&metric_thresholds)[4])
    {
        // Find already connected known remote parties.
        std::set<conf::ip_port_pair> known_remotes;
        for (const auto &[fd, session] : sessions)
        {
            if (session.state != SESSION_STATE::CLOSED && !session.known_ipport.first.empty())
                known_remotes.emplace(session.known_ipport);
        }

        for (const auto &ipport : req_known_remotes)
        {
            if (should_stop_listening)
                break;

            // Check if we are already connected to this remote party.
            if (known_remotes.count(ipport) > 0)
                continue;

            std::string_view host = ipport.first;
            const uint16_t port = ipport.second;
            LOG_DEBUG << "Trying to connect " << host << ":" << std::to_string(port);

            comm_client client;
            // NOTE(review): the client gets conf::cfg.peermaxsize while the session below gets the max_msg_size
            // parameter, and the session type is always PEER (session_type is unused) — confirm intended.
            if (client.start(host, port, metric_thresholds, conf::cfg.peermaxsize) == -1)
            {
                LOG_ERROR << "Outbound connection attempt failed: " << host << ":" << std::to_string(port);
                continue;
            }

            comm_session session(host, client.read_fd, client.write_fd, SESSION_TYPE::PEER, is_binary, false, metric_thresholds, max_msg_size);
            session.known_ipport = ipport;
            if (session.on_connect() != 0)
                continue;

            // Copy the key now; `client` is moved from further below.
            const int read_fd = client.read_fd;
            {
                std::scoped_lock lock(sessions_mutex);
                const auto [itr, inserted] = sessions.try_emplace(read_fd, std::move(session));
                if (inserted)
                {
                    // Threads are started separately, after the move, to overcome the difficulty of accessing
                    // class member variables inside the thread: members hold unacceptable values if the thread
                    // starts before the move operation.
                    itr->second.start_messaging_threads();
                }
                else
                {
                    LOG_ERROR << "Session already registered for outbound socket: " << read_fd;
                }
            }
            outbound_clients.emplace(read_fd, std::move(client));
            known_remotes.emplace(ipport);
        }
    }

    /**
     * Message processor thread body. Round-robins over all sessions while holding sessions_mutex, calling
     * process_next_inbound_message() once per session per pass. A non-zero result counts as work done;
     * a result of -1 marks the session for closure. When a pass did no work the thread sleeps via
     * util::sleep(10).
     */
    void comm_server::inbound_message_processor_loop(const SESSION_TYPE session_type)
    {
        util::mask_signal();

        while (!should_stop_listening)
        {
            bool messages_processed = false;

            {
                // Process one message from each session in round-robin fashion.
                std::scoped_lock lock(sessions_mutex);
                for (auto &[fd, session] : sessions)
                {
                    const int result = session.process_next_inbound_message();

                    if (result != 0)
                        messages_processed = true;

                    if (result == -1)
                        session.mark_for_closure();
                }
            }

            // If no messages were processed in this cycle, wait for some time.
            if (!messages_processed)
                util::sleep(10);
        }

        LOG_INFO << (session_type == SESSION_TYPE::USER ? "User" : "Peer") << " message processor stopped.";
    }

    /**
     * Forks and execs websocketd, which listens on `port` (TLS) and relays each connection to the domain
     * socket through `nc -U`. When a pipe can be created, its read end becomes websocketd's stdin and its
     * write end is kept in firewall_out for firewall_ban().
     * @return 0 when websocketd is still alive shortly after starting, -1 otherwise.
     */
    int comm_server::start_websocketd_process(
        const uint16_t port, const char *domain_socket_name, const bool is_binary,
        const bool use_size_header, const uint64_t max_msg_size)
    {
        // Set up the parent-to-child pipe used for firewall commands.
        int firewall_pipe[2];
        const bool has_firewall_pipe = (pipe(firewall_pipe) == 0);
        if (has_firewall_pipe)
            firewall_out = firewall_pipe[1];
        else
            LOG_ERROR << errno << ": pipe() call failed for firewall";

        // Build the argument list BEFORE fork(): other threads are already running at this point and heap
        // allocation in the child of a multi-threaded process is not async-signal-safe.
        // We are using websocketd forked repo: https://github.com/codetsunami/websocketd
        const std::string max_msg_size_str = (max_msg_size > 0) ? std::to_string(max_msg_size) : "134217728"; // 128MB
        std::vector<std::string> args_vec;
        args_vec.reserve(16);
        args_vec.push_back(conf::ctx.websocketd_exe_path);
        args_vec.push_back("--port");
        args_vec.push_back(std::to_string(port));
        args_vec.push_back("--ssl");
        args_vec.push_back("--sslcert");
        args_vec.push_back(conf::ctx.tls_cert_file);
        args_vec.push_back("--sslkey");
        args_vec.push_back(conf::ctx.tls_key_file);
        args_vec.push_back(is_binary ? "--binary=true" : "--binary=false");
        args_vec.push_back(use_size_header ? "--sizeheader=true" : "--sizeheader=false");
        args_vec.push_back(std::string("--maxframe=").append(max_msg_size_str));
        args_vec.push_back("--loglevel=error");
        args_vec.push_back("nc"); // netcat (OpenBSD) is used for domain socket redirection.
        args_vec.push_back("-U"); // Use UNIX domain socket
        args_vec.push_back(domain_socket_name);

        // execv() requires a NULL-terminated argv, so one slot more than the number of arguments.
        std::vector<char *> execv_args;
        execv_args.reserve(args_vec.size() + 1);
        for (std::string &arg : args_vec)
            execv_args.push_back(arg.data());
        execv_args.push_back(NULL);

        const pid_t pid = fork();
        if (pid == -1)
        {
            LOG_ERROR << errno << ": fork() failed when starting websocketd process.";
            if (has_firewall_pipe)
            {
                close(firewall_pipe[0]);
                close(firewall_pipe[1]);
                firewall_out = -1;
            }
            return -1;
        }

        if (pid == 0)
        {
            // Websocketd process.
            util::unmask_signal();

            if (has_firewall_pipe)
            {
                // Close parent writing end of the pipe in the child.
                close(firewall_pipe[1]);
                // Override stdin in the child's file table.
                dup2(firewall_pipe[0], 0);
                if (firewall_pipe[0] != 0)
                    close(firewall_pipe[0]);
            }

            execv(execv_args[0], execv_args.data());
            std::cerr << errno << ": websocketd process execv failed.\n";
            // _exit: do not run the parent's atexit handlers / flush inherited buffers in the child.
            _exit(1);
        }

        // HotPocket process.
        // Close the child reading end of the pipe in the parent.
        if (has_firewall_pipe)
            close(firewall_pipe[0]);

        // Wait for some time and check if websocketd is still running properly.
        // Sending signal 0 to test whether the process exists.
        // NOTE(review): if util::kill_process is a plain kill(2), an exited but not-yet-reaped (zombie) child
        // still passes this check — consider waitpid(WNOHANG).
        util::sleep(20);
        if (util::kill_process(pid, false, 0) == -1)
        {
            // Nobody reads the pipe any more; stop firewall_ban() from writing to it.
            if (has_firewall_pipe)
            {
                close(firewall_pipe[1]);
                firewall_out = -1;
            }
            return -1;
        }

        websocketd_pid = pid;
        return 0;
    }

    /**
     * Sends a firewall command to websocketd over the firewall pipe: the single character "a" (ban) or
     * "r" (unban) immediately followed by the IP bytes. Does nothing when no pipe is available.
     * NOTE(review): no delimiter separates consecutive commands and the writev() result is ignored —
     * confirm the websocketd side frames commands per write.
     */
    void comm_server::firewall_ban(std::string_view ip, const bool unban)
    {
        if (firewall_out < 0)
            return;

        iovec iov[]{
            {(void *)(unban ? "r" : "a"), 1},
            {(void *)ip.data(), ip.length()}};
        writev(firewall_out, iov, 2);
    }

    /**
     * If the fd supplied was produced by accept()ing a unix domain socket connection, the process at the
     * other end is inspected for CGI environment variables and the value of REMOTE_ADDR is returned.
     * Returns an empty string on any failure or when the variable is absent or empty.
     */
    std::string comm_server::get_cgi_ip(const int fd)
    {
        // Ask the operating system for information about the other process.
        ucred uc;
        socklen_t length = sizeof(uc);
        if (getsockopt(fd, SOL_SOCKET, SO_PEERCRED, &uc, &length) == -1)
        {
            LOG_ERROR << errno << ": Could not retrieve PID from unix domain socket";
            return "";
        }

        // Open /proc/<pid>/environ for that process.
        const std::string environ_path = "/proc/" + std::to_string(uc.pid) + "/environ";
        const int envfd = open(environ_path.c_str(), O_RDONLY | O_CLOEXEC);
        if (envfd == -1)
        {
            LOG_ERROR << errno << ": Could not open environ block for process on other end of unix domain socket PID=" << uc.pid;
            return "";
        }

        // Read the environ block with a single read(); only the first 0x7fff bytes are inspected.
        char envblock[0x7fff];
        const ssize_t bytes_read = read(envfd, envblock, sizeof(envblock));
        const int read_errno = errno; // close() below may overwrite errno.
        close(envfd);
        if (bytes_read == -1)
        {
            LOG_ERROR << read_errno << ": Could not read environ block for process on other end of unix domain socket PID=" << uc.pid;
            return "";
        }

        // Find the REMOTE_ADDR entry. Entries in the environ block are delimited by '\0'; only entries whose
        // terminator lies within the bytes read are considered.
        const char *entry = envblock;
        for (const char *upto = envblock; upto < envblock + bytes_read; ++upto)
        {
            if (*upto != '\0')
                continue;

            const size_t entry_len = static_cast<size_t>(upto - entry);
            if (entry_len > 12 && strncmp(entry, "REMOTE_ADDR=", 12) == 0)
                return std::string(entry + 12, entry_len - 12);

            entry = upto + 1;
        }

        LOG_ERROR << "Could not find REMOTE_ADDR variable in /proc/" << uc.pid << "/environ";
        return "";
    }

    /**
     * Signals both worker threads to stop, waits for them and kills websocketd (if it was started).
     * Safe to call when start() failed before the threads were created.
     */
    void comm_server::stop()
    {
        should_stop_listening = true;

        // join() on a thread that was never started throws std::system_error.
        if (watchdog_thread.joinable())
            watchdog_thread.join();

        if (inbound_message_processor_thread.joinable())
            inbound_message_processor_thread.join();

        if (websocketd_pid > 0)
            util::kill_process(websocketd_pid, false); // Kill websocketd.
    }

} // namespace comm