diff --git a/CMakeLists.txt b/CMakeLists.txt index da3a515..d4b7d08 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,6 +13,8 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY build) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-result -Wreturn-type") add_executable(sagent + src/comm/comm_handler.cpp + src/comm/comm_session.cpp src/util/util.cpp src/conf.cpp src/salog.cpp @@ -22,7 +24,12 @@ add_executable(sagent target_link_libraries(sagent sqlite3 + pthread ${CMAKE_DL_LIBS} # Needed for stacktrace support ) +add_custom_command(TARGET sagent POST_BUILD + COMMAND cp ./dependencies/bin/hpws ./build/ +) + target_precompile_headers(sagent PUBLIC src/pchheader.hpp) diff --git a/README.md b/README.md index 9c8112d..53e78b2 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,10 @@ Code is divided into subsystems via namespaces. **conf::** Handles configuration. Loads and holds the central configuration object. Used by most of the subsystems. +**salog::** Handles logging. Creates and prints the logs according to the configured log section in the json config. + +**comm::** Handles generic web sockets communication functionality. Mainly acts as a wrapper for [hpws](https://github.com/RichardAH/hpws). + **util::** Contains shared data structures/helper functions used by multiple subsystems. **sqlite::** Contains sqlite database management related helper functions. \ No newline at end of file diff --git a/dependencies/bin/hpws b/dependencies/bin/hpws new file mode 100755 index 0000000..69c98f5 Binary files /dev/null and b/dependencies/bin/hpws differ diff --git a/src/comm/comm_handler.cpp b/src/comm/comm_handler.cpp new file mode 100644 index 0000000..9b920e0 --- /dev/null +++ b/src/comm/comm_handler.cpp @@ -0,0 +1,80 @@ +#include "comm_handler.hpp" +#include "../util/util.hpp" +#include "hpws.hpp" +#include "comm_session.hpp" + +namespace comm +{ + constexpr uint32_t DEFAULT_MAX_MSG_SIZE = 5 * 1024 * 1024; + std::optional session; + bool init_success; + + int init() + { + if (connect(conf::cfg.server.ip_port) == -1) + return -1; + + init_success = true; + + return 0; + } + + void deinit() + { + if (init_success) + disconnect(); + } + + int connect(const conf::host_ip_port &ip_port) + { + std::string_view host = ip_port.host_address; + const uint16_t port = ip_port.port; + + LOG_DEBUG << "Trying to connect " << host << ":" << std::to_string(port); + + std::variant client_result = hpws::client::connect(conf::ctx.hpws_exe_path, DEFAULT_MAX_MSG_SIZE, host, port, "/", {}, util::fork_detach); + + if (std::holds_alternative(client_result)) + { + const hpws::error error = std::get(client_result); + if (error.first != 202) + LOG_ERROR << "Connection hpws error:" << error.first << " " << error.second; + return -1; + } + else + { + hpws::client client = std::move(std::get(client_result)); + const std::variant host_result = client.host_address(); + if (std::holds_alternative(host_result)) + { + const hpws::error error = std::get(host_result); + LOG_ERROR << "Error getting ip from hpws:" << error.first << " " << error.second; + return -1; + } + else + { + const std::string &host_address = std::get(host_result); + session.emplace(host_address, std::move(client)); + session->init(); + } + } + return 0; + } + + void disconnect() + { + if (session.has_value()) + { + session->close(); + session.reset(); + } + } + + /** + * Wait for the session. + */ + void wait() + { + session->wait(); + } +} // namespace comm diff --git a/src/comm/comm_handler.hpp b/src/comm/comm_handler.hpp new file mode 100644 index 0000000..f2e00dd --- /dev/null +++ b/src/comm/comm_handler.hpp @@ -0,0 +1,21 @@ +#ifndef _SA_COMM_COMM_SERVER_ +#define _SA_COMM_COMM_SERVER_ + +#include "../pchheader.hpp" +#include "../conf.hpp" + +namespace comm +{ + int init(); + + void deinit(); + + int connect(const conf::host_ip_port &ip_port); + + void disconnect(); + + void wait(); + +} // namespace comm + +#endif diff --git a/src/comm/comm_session.cpp b/src/comm/comm_session.cpp new file mode 100644 index 0000000..f050949 --- /dev/null +++ b/src/comm/comm_session.cpp @@ -0,0 +1,91 @@ +#include "../pchheader.hpp" +#include "../util/util.hpp" +#include "../conf.hpp" +#include "hpws.hpp" +#include "comm_session.hpp" + +namespace comm +{ + + comm_session::comm_session( + std::string_view host_address, hpws::client &&hpws_client) + : uniqueid(host_address), + host_address(host_address), + hpws_client(std::move(hpws_client)) + { + } + + /** + * Init() should be called to activate the session. + * Because we are starting threads here, after init() is called, the session object must not be "std::moved". + * @return returns 0 on successful init, otherwise -1; + */ + int comm_session::init() + { + if (state == SESSION_STATE::NONE) + { + reader_thread = std::thread(&comm_session::reader_loop, this); + state = SESSION_STATE::ACTIVE; + LOG_DEBUG << "Session started: " << uniqueid; + } + + return 0; + } + + void comm_session::reader_loop() + { + util::mask_signal(); + + while (state != SESSION_STATE::CLOSED && hpws_client) + { + const std::variant read_result = hpws_client->read(); + if (std::holds_alternative(read_result)) + { + const hpws::error error = std::get(read_result); + if (error.first != 1) // 1 indicates channel has closed. + LOG_DEBUG << "hpws client read failed:" << error.first << " " << error.second; + } + else + { + // Enqueue the message for processing. + std::string_view data = std::get(read_result); + + LOG_INFO << "Received message : " << data; + + // Signal the hpws client that we are ready for next message. + const std::optional error = hpws_client->ack(data); + if (error.has_value()) + LOG_DEBUG << "hpws client ack failed:" << error->first << " " << error->second; + } + } + } + + /** + * Close the connection and wrap up any session processing threads. + * This will be only called by the global comm_server thread. + */ + void comm_session::close() + { + if (state == SESSION_STATE::CLOSED) + return; + + state = SESSION_STATE::CLOSED; + + // Destruct the hpws client instance so it will close the sockets and related processes. + hpws_client.reset(); + + if (reader_thread.joinable()) + reader_thread.join(); + + LOG_DEBUG << "Session closed: " << uniqueid; + } + + /** + * Joins the listner thread. + */ + void comm_session::wait() + { + reader_thread.join(); + } + +} // namespace comm \ No newline at end of file diff --git a/src/comm/comm_session.hpp b/src/comm/comm_session.hpp new file mode 100644 index 0000000..ce1ebb8 --- /dev/null +++ b/src/comm/comm_session.hpp @@ -0,0 +1,42 @@ +#ifndef _HP_COMM_COMM_SESSION_ +#define _HP_COMM_COMM_SESSION_ + +#include "../pchheader.hpp" +#include "../conf.hpp" +#include "hpws.hpp" + +namespace comm +{ + enum SESSION_STATE + { + NONE, // Session is not yet initialized properly. + ACTIVE, // Session is active and functioning. + MUST_CLOSE, // Session socket is in unusable state and must be closed. + CLOSED // Session is fully closed. + }; + + /** + * Represents an active WebSocket connection + */ + class comm_session + { + private: + SESSION_STATE state = SESSION_STATE::NONE; + std::optional hpws_client; + const std::string uniqueid; // Verified session: Pubkey in hex format, Unverified session: IP address. + const std::string host_address; // Connection host address of the remote party. + std::thread reader_thread; // The thread responsible for reading messages from the read fd. + + void reader_loop(); + + public: + comm_session( + std::string_view host_address, hpws::client &&hpws_client); + int init(); + void close(); + void wait(); + }; + +} // namespace comm + +#endif diff --git a/src/comm/hpws.hpp b/src/comm/hpws.hpp new file mode 100644 index 0000000..737f538 --- /dev/null +++ b/src/comm/hpws.hpp @@ -0,0 +1,995 @@ +#ifndef HPWS_INCLUDE +#define HPWS_INCLUDE +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define DECODE_O_SIZE(control_msg, into) \ + { \ + into = ((uint32_t)control_msg[2] << 24) + ((uint32_t)control_msg[3] << 16) + \ + ((uint32_t)control_msg[4] << 8) + ((uint32_t)control_msg[5] << 0); \ + } + +#define ENCODE_O_SIZE(control_msg, from) \ + { \ + uint32_t f = from; \ + control_msg[2] = (unsigned char)((f >> 24) & 0xff); \ + control_msg[3] = (unsigned char)((f >> 16) & 0xff); \ + control_msg[4] = (unsigned char)((f >> 8) & 0xff); \ + control_msg[5] = (unsigned char)((f >> 0) & 0xff); \ + } + +#define HPWS_DEBUG 0 + +namespace hpws +{ + /*typedef enum e_retcode { + SUCCESS + } retcode; + */ + using error = std::pair; + +// used when waiting for messages that should already be on the pipe +#define HPWS_SMALL_TIMEOUT 10 +// used when waiting for server process to spawn +#define HPWS_LONG_TIMEOUT 1500 // This timeout has to account the possible delays in communication via internet. + + typedef union + { + struct sockaddr sa; + struct sockaddr_in sin; + struct sockaddr_in6 sin6; + struct sockaddr_storage ss; + } addr_t; + + class server; + + class client + { + + private: + pid_t child_pid = 0; // if this client was created by a connect this is set + // this value can't be changed once it's established between the processes + uint32_t max_buffer_size; + bool moved = false; + addr_t endpoint; + std::string get; // the get req this websocket was opened with + int control_line_fd[2]; // see below in client constructor + int buffer_fd[4]; // 0 1 - in buffers, 2 3 - out buffers + int buffer_lock[2] = {0, 0}; // this records if buffers 2 and 3 have been sent out awaiting an ack or not + void *buffer[4]; + + // private constructor + client( + std::string_view get, + addr_t endpoint, + int control_line_fd_0, // hpws -> sagent [ hpws sends bufs to us over this line ] + int control_line_fd_1, // sagent -> hpws [ we send bufs to hpws over this line ] + uint32_t max_buffer_size, + pid_t child_pid, + int buffer_fd[4], + void *buffer[4]) : endpoint(endpoint), + max_buffer_size(max_buffer_size), + child_pid(child_pid), get(get) + { + control_line_fd[0] = control_line_fd_0; + control_line_fd[1] = control_line_fd_1; + for (int i = 0; i < 4; ++i) + { + this->buffer[i] = buffer[i]; + this->buffer_fd[i] = buffer_fd[i]; + } + + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] child constructed pid = %d\n", child_pid); + } + + public: + // No copy constructor + client(const client &) = delete; + + // only a move constructor + client(client &&old) : child_pid(old.child_pid), + max_buffer_size(old.max_buffer_size), + endpoint(old.endpoint), + get(old.get) + { + old.moved = true; + for (int i = 0; i <= 1; ++i) + { + this->control_line_fd[i] = old.control_line_fd[i]; + buffer_lock[i] = old.buffer_lock[i]; + } + for (int i = 0; i < 4; ++i) + { + this->buffer[i] = old.buffer[i]; + this->buffer_fd[i] = old.buffer_fd[i]; + } + } + + ~client() + { + if (!moved) + { + + // RH TODO ensure this pid terminates by following up with a SIGKILL + if (child_pid > 0) + { + kill(child_pid, SIGTERM); + int status; + waitpid(child_pid, &status, 0 /* should we use WNOHANG? */); + } + + for (int i = 0; i < 4; ++i) + { + munmap(buffer[i], max_buffer_size); + ::close(buffer_fd[i]); + } + + ::close(control_line_fd[0]); + ::close(control_line_fd[1]); + + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] child destructed pid = %d\n", child_pid); + } + } + + const std::variant host_address() + { + char hostname[NI_MAXHOST]; + const int ret = getnameinfo((sockaddr *)&endpoint, sizeof(sockaddr), hostname, sizeof(hostname), NULL, 0, NI_NUMERICHOST); + if (ret != 0) + return error{10, gai_strerror(ret)}; + + return hostname; + } + + std::variant read() + { + + unsigned char buf[32]; + int bytes_read = recv(control_line_fd[0], buf, sizeof(buf), 0); + if (bytes_read < 1) + { + if (HPWS_DEBUG) + { + perror("recv"); + fprintf(stderr, "[HPWS.HPP] bytes received %d\n", bytes_read); + } + return error{1, "[read] control line could not be read"}; // todo clean up somehow? + } + + switch (buf[0]) + { + case 'o': + { + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] o message received\n"); + + if (bytes_read != 6) + return error{3, "invalid buffer in 'o' command sent by hpws"}; + + uint32_t len = 0; + DECODE_O_SIZE(buf, len); + + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] o message len: %u\n", len); + + int bufno = buf[1] - '0'; + if (bufno != 0 && bufno != 1) + return error{3, "invalid buffer in 'o' command sent by hpws"}; + + if (HPWS_DEBUG) + { + fprintf(stderr, "[HPWS.HPP] read %d\n", len); + for (uint32_t i = 0; i < len; ++i) + putc(((char *)(buffer[bufno]))[i], stderr); + fprintf(stderr, "\n---\n"); + } + return std::string_view{(const char *)(buffer[bufno]), len}; + } + case 'c': + { + return error{1000, "ws closed"}; + } + default: + { + fprintf(stderr, "[HPWS.HPP] read unknown control message 1: `%.*s`\n", bytes_read, buf); + return error{2, "unknown control line command was sent by hpws"}; + } + } + } + + void close() + { + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] close called\n"); + + // send the control message informing hpws that we wish to close + char buf[1] = {'c'}; + + ::write(control_line_fd[1], buf, 1); + + // wait for the process to end gracefully + int status; + printf("waitpid result: %d\n", waitpid(child_pid, &status, 0)); // add timeout here? + } + + std::optional write(std::string_view to_write) + { + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] write called for message: `%.*s`\n", + (int)to_write.size(), to_write.data()); + + // check if we have any free buffers + if (buffer_lock[0] && buffer_lock[1]) + { + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] no free buffers while writing, waiting for an ack on control_line_fd[1]=%d\n", control_line_fd[1]); + + // no free buffers, wait for a ack + unsigned char buf[32]; + int bytes_read = 0; + + bytes_read = recv(control_line_fd[1], buf, sizeof(buf), 0); + if (bytes_read < 1) + { + perror("recv"); + return error{1, "[write] control line could not be read"}; // todo clean up somehow? + } + + switch (buf[0]) + { + case 'a': + { + if (bytes_read != 2) + return error{4, "received an ack longer than 2 bytes"}; + int bufno = buf[1] - '0'; + if (!(bufno == 0 || bufno == 1)) + return error{5, "received an ack with an invalid buffer, expecting 0 or 1"}; + // unlock the buffer + buffer_lock[bufno] = 0; + break; + } + case 'c': + return error{1000, "ws closed"}; + default: + fprintf(stderr, "[HPWS.HPP] read unknown control message 2: `%.*s`\n", bytes_read, buf); + return error{2, "unknown control line command was sent by hpws"}; + } + } + + // execution to here ensures at least one buffer is free + int bufno = (buffer_lock[0] == 0 ? 2 : 3); + + // update the selected buffer lock + buffer_lock[bufno - 2] = 1; + + // write into the buffer + memcpy(buffer[bufno], to_write.data(), to_write.size()); + + // send the control message informing hpws that a message is ready on this buffer + uint32_t len = to_write.size(); + char buf[6] = {'o', (char)('0' + (bufno - 2)), 0, 0, 0, 0}; + ENCODE_O_SIZE(buf, len); + + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] writing 'o' to control_line_fd[1]=%d\n", + control_line_fd[1]); + + if (::write(control_line_fd[1], buf, 6) != 6) + return error{6, "could not write o message to control line"}; + + return std::nullopt; + } + + std::optional ack(std::string_view from_read) + { + char msg[2] = {'a', '0'}; + if (from_read.data() == buffer[1]) + msg[1]++; + if (send(control_line_fd[0], msg, 2, 0) < 2) + return error{10, "could not send ack down control line"}; + return std::nullopt; + } + + static std::variant connect( + std::string_view bin_path, + uint32_t max_buffer_size, + std::string_view host, + uint16_t port, + std::string_view get, + std::vector argv, + std::function fork_child_init = NULL) + { + +#define HPWS_CONNECT_ERROR(code, msg) \ + { \ + error_code = code; \ + error_msg = msg; \ + goto connect_error; \ + } + + int error_code = -1; + const char *error_msg = NULL; + int fd[4] = {-1, -1, -1, -1}; // 0,1 are hpws->sagent, 2,3 are sagent->hpws + int buffer_fd[4] = {-1, -1, -1, -1}; + void *mapping[4] = {NULL, NULL, NULL, NULL}; + int pid = -1; + int count_args = 14 + argv.size(); + char const **argv_pass = NULL; + + if (socketpair(AF_UNIX, SOCK_SEQPACKET, 0, fd)) + HPWS_CONNECT_ERROR(100, "could not create unix domain socket pair"); + + if (socketpair(AF_UNIX, SOCK_SEQPACKET, 0, fd + 2)) + HPWS_CONNECT_ERROR(101, "could not create unix domain socket pair"); + + // construct the arguments + char shm_size[32]; + + if (snprintf(shm_size, 32, "%d", max_buffer_size) <= 0) + HPWS_CONNECT_ERROR(90, "couldn't write shm size to string"); + + char port_str[6]; + if (snprintf(port_str, 6, "%d", port) <= 0) + HPWS_CONNECT_ERROR(91, "couldn't write port to string"); + + char cfd1[10]; + char cfd2[10]; + + snprintf(cfd1, 10, "%d", fd[1]); + snprintf(cfd2, 10, "%d", fd[3]); + + argv_pass = + reinterpret_cast(alloca(sizeof(char *) * count_args)); + { + int upto = 0; + argv_pass[upto++] = bin_path.data(); + argv_pass[upto++] = "--client"; + argv_pass[upto++] = "--maxmsg"; + argv_pass[upto++] = shm_size; + argv_pass[upto++] = "--host"; + argv_pass[upto++] = host.data(); + argv_pass[upto++] = "--port"; + argv_pass[upto++] = port_str; + argv_pass[upto++] = "--cntlfd"; + argv_pass[upto++] = cfd1; + argv_pass[upto++] = "--cntlfd2"; + argv_pass[upto++] = cfd2; + argv_pass[upto++] = "--get"; + argv_pass[upto++] = get.data(); + for (std::string_view &arg : argv) + argv_pass[upto++] = arg.data(); + argv_pass[upto] = NULL; + } + + pid = vfork(); + + if (pid) + { + + // --- PARENT + + // Fds are set to -1, so when error occurred these fds won't get closed again. + ::close(fd[1]); + fd[1] = -1; + ::close(fd[3]); + fd[3] = -1; + + int child_fd[2] = {fd[0], fd[2]}; + + int flags[2] = { + fcntl(child_fd[0], F_GETFD, NULL), + fcntl(child_fd[1], F_GETFD, NULL)}; + + if (flags[0] < 0 || flags[1] < 0) + HPWS_CONNECT_ERROR(101, "could not get flags from unix domain socket"); + + flags[0] |= FD_CLOEXEC; + flags[1] |= FD_CLOEXEC; + if (fcntl(child_fd[0], F_SETFD, flags[0]) || fcntl(child_fd[1], F_SETFD, flags[1])) + HPWS_CONNECT_ERROR(102, "could notset flags for unix domain socket"); + + // we will set a timeout and wait for the initial startup message from hpws client mode + struct pollfd pfd; + int ret; + + pfd.fd = child_fd[0]; // we receive setup events on control line 0 (hpws->sagent) + pfd.events = POLLIN; + ret = poll(&pfd, 1, HPWS_LONG_TIMEOUT); // default= 1500 ms timeout + + // timeout or error + if (ret < 1) + HPWS_CONNECT_ERROR(1, "timeout waiting for hpws connect message"); + + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] waiting for addr_t\n"); + // first thing we'll receive is the sockaddr union + addr_t child_addr; + + int bytes_read = + recv(child_fd[0], (unsigned char *)(&child_addr), sizeof(child_addr), 0); + + if (bytes_read < sizeof(child_addr)) + HPWS_CONNECT_ERROR(202, "received message on control line was not sizeof(addr_t)"); + + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] waiting for buffer fds\n"); + + // second thing we will receive is the four fds for the buffers + { + struct msghdr child_msg = {0}; + memset(&child_msg, 0, sizeof(child_msg)); + char cmsgbuf[CMSG_SPACE(sizeof(int) * 4)]; + child_msg.msg_control = cmsgbuf; + child_msg.msg_controllen = sizeof(cmsgbuf); + + int bytes_read = + recvmsg(child_fd[0], &child_msg, 0); + struct cmsghdr *cmsg = CMSG_FIRSTHDR(&child_msg); + if (cmsg == NULL || cmsg->cmsg_type != SCM_RIGHTS) + HPWS_CONNECT_ERROR(203, "non-scm_rights message sent on accept child control line"); + memcpy(&buffer_fd, CMSG_DATA(cmsg), sizeof(buffer_fd)); + for (int i = 0; i < 4; ++i) + { + //fprintf(stderr, "scm passed buffer_fd[%d] = %d\n", i, buffer_fd[i]); + if (buffer_fd[i] < 0) + HPWS_CONNECT_ERROR(203, "child accept scm_rights a passed buffer fd was negative"); + mapping[i] = + mmap(0, max_buffer_size, PROT_READ | PROT_WRITE, MAP_SHARED, buffer_fd[i], 0); + if (mapping[i] == (void *)(-1)) + HPWS_CONNECT_ERROR(204, "could not mmap scm_rights passed buffer fd"); + } + } + + for (int i = 0; i <= 1; ++i) + { + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] waiting for 'r' on child_fd[%d]=%d\n", i, child_fd[i]); + + pfd.fd = child_fd[i]; + // now we wait for a 'r' ready message or for the socket/client to die + ret = poll(&pfd, 1, HPWS_LONG_TIMEOUT); // default= 1500 ms timeout + + char rbuf[2]; + bytes_read = recv(child_fd[i], rbuf, sizeof(rbuf), 0); + if (bytes_read < 1) + HPWS_CONNECT_ERROR(2, "nil message sent by hpws on startup"); + + if (rbuf[0] != 'r') + HPWS_CONNECT_ERROR(3, "unexpected content in message sent by hpws client mode on startup"); + + if (rbuf[1] != '0' + i) + HPWS_CONNECT_ERROR(4, "received wrong r message on control fd"); + + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] received 'r%c' on child_fd[%d]=%d\n", rbuf[1], i, child_fd[i]); + } + + return client{ + get, + child_addr, + child_fd[0], + child_fd[1], + max_buffer_size, + pid, + buffer_fd, + mapping}; + } + else + { + + // --- CHILD + if (fork_child_init) + fork_child_init(); + + ::close(fd[0]); + ::close(fd[2]); + + // dup fd[1] into fd 3 + /*if (dup2(fd[1], 3) == -1) + perror("dup2 fd[1]"); + if (dup2(fd[3], 4) == -1) + perror("dup2 fd[3]"); + */ + // ::close(fd[1]); + // ::close(fd[3]); + + // we're assuming all fds above 3 will have close_exec flag + execv(bin_path.data(), (char *const *)argv_pass); + // we will send a nil message down the pipe to help the parent know somethings gone wrong + char nil[1]; + nil[0] = 0; + send(3, nil, 1, 0); + exit(1); // execl failure as child will always result in exit here + } + + connect_error:; + + // NB: execution to here can only happen in parent process + // clean up any mess after error + if (pid > 0) + { + kill((pid_t)pid, SIGKILL); /* RH TODO change this to SIGTERM and set a timeout? */ + int status; + waitpid(pid, &status, 0 /* should we use WNOHANG? */); + } + for (int i = 0; i < 4; ++i) + { + if (fd[i] > 0) + ::close(fd[i]); + if (mapping[i] != MAP_FAILED && mapping[i] != NULL) + munmap(mapping[i], max_buffer_size); + if (buffer_fd[i] > -1) + ::close(buffer_fd[i]); + } + + return error{error_code, std::string{error_msg}}; + } + friend class server; + }; + + class server + { + + private: + pid_t server_pid_; + int master_control_fd_; + uint32_t max_buffer_size_; + bool moved = false; + + // private constructor + server(pid_t server_pid, int master_control_fd, uint32_t max_buffer_size) + : server_pid_(server_pid), master_control_fd_(master_control_fd), max_buffer_size_(max_buffer_size) {} + + void accept_cleanup(void *mapping[4], int child_fd[2], int buffer_fd[4], uint32_t pid_child) + { + for (int i = 0; i < 4; i++) + { + if (mapping[i] != MAP_FAILED && mapping[i] != NULL) + munmap(mapping[i], max_buffer_size_); + if (i < 2 && child_fd[i] > -1) + ::close(child_fd[i]); + if (buffer_fd[i] > -1) + ::close(buffer_fd[i]); + } + + if (pid_child > 0) + { + int ret1 = kill(pid_child, SIGTERM); + int wstat; + int ret2 = waitpid(pid_child, &wstat, 0); + } + } + + public: + // No copy constructor + server(const server &) = delete; + + // only a move constructor + server(server &&old) : server_pid_(old.server_pid_), + master_control_fd_(old.master_control_fd_), + max_buffer_size_(old.max_buffer_size_) + { + old.moved = true; + } + + pid_t server_pid() + { + return server_pid_; + } + + int master_control_fd() + { + return master_control_fd_; + } + + uint32_t max_buffer_size() + { + return max_buffer_size_; + } + + std::variant accept(const bool no_block = false) + { + + static int calls = 0; + ++calls; + + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] Accept[0] called %d\n", calls); + +#define HPWS_ACCEPT_ERROR(code, msg) \ + { \ + accept_cleanup(mapping, child_fd, buffer_fd, pid); \ + return error{code, msg}; \ + } + + int child_fd[2] = {-1, -1}; + int buffer_fd[4] = {-1, -1, -1, -1}; + void *mapping[4] = {NULL, NULL, NULL, NULL}; + // must not use pid_t here since we transfer across IPC channel as a uint32. + uint32_t pid = 0; + + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] Accept[1] called %d\n", calls); + { + struct msghdr child_msg = {0}; + memset(&child_msg, 0, sizeof(child_msg)); + char cmsgbuf[CMSG_SPACE(sizeof(int) * 2)]; + child_msg.msg_control = cmsgbuf; + child_msg.msg_controllen = sizeof(cmsgbuf); + + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] Accept[2] called %d\n", calls); + + // If no-block is specified, we first check any bytes available on control fd + // before attempting to do a blocking a read. + if (no_block) + { + struct pollfd master_pfd; + master_pfd.fd = this->master_control_fd_; + master_pfd.events = POLLERR | POLLHUP | POLLNVAL | POLLIN; + const int master_poll_result = poll(&master_pfd, 1, HPWS_SMALL_TIMEOUT); + + if (master_poll_result == -1) // 1 ms timeout + HPWS_ACCEPT_ERROR(200, "poll failed on master control line"); + + if (master_poll_result == 0) // No data available + HPWS_ACCEPT_ERROR(199, "no new client available"); + } + + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] Accept[3] called %d\n", calls); + + int bytes_read = + recvmsg(this->master_control_fd_, &child_msg, 0); + struct cmsghdr *cmsg = CMSG_FIRSTHDR(&child_msg); + if (cmsg == NULL || cmsg->cmsg_type != SCM_RIGHTS) + HPWS_ACCEPT_ERROR(200, "non-scm_rights message sent on master control line"); + memcpy(&child_fd, CMSG_DATA(cmsg), sizeof(child_fd)); + if (child_fd[0] < 0 || child_fd[1] < 0) + HPWS_ACCEPT_ERROR(201, "scm_rights passed fd/s were negative"); + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] On accept received SCM: child_fd[0] = %d, child_fd[1] = %d\n", + child_fd[0], child_fd[1]); + + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] Accept[4] called %d\n", calls); + } + + // read info from child control line with a timeout + struct pollfd pfd; + int ret; + + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] Accept[5] called %d\n", calls); + + pfd.fd = child_fd[0]; // expect all setup messages on the hpws->sagent controlfd (0) + pfd.events = POLLIN; + ret = poll(&pfd, 1, HPWS_SMALL_TIMEOUT); // 1 ms timeout + + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] Accept[6] called %d\n", calls); + + // timeout or error + if (ret < 1) + HPWS_ACCEPT_ERROR(202, "timeout waiting for hpws accept child message"); + + // first thing we'll receive is the pid of the client + if (recv(child_fd[0], (unsigned char *)(&pid), sizeof(pid), 0) < sizeof(pid)) + HPWS_ACCEPT_ERROR(212, "did not receive expected 4 byte pid of child process on accept"); + + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] Accept[7] called %d\n", calls); + + // second thing we'll receive is IP address structure of the client + addr_t buf; + int bytes_read = + recv(child_fd[0], (unsigned char *)(&buf), sizeof(buf), 0); + + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] Accept[8] called %d\n", calls); + + if (bytes_read < sizeof(buf)) + HPWS_ACCEPT_ERROR(202, "received message on master control line was not sizeof(sockaddr_in6)"); + + // third thing we will receive is the four fds for the buffers + { + struct msghdr child_msg = {0}; + memset(&child_msg, 0, sizeof(child_msg)); + char cmsgbuf[CMSG_SPACE(sizeof(int) * 4)]; + child_msg.msg_control = cmsgbuf; + child_msg.msg_controllen = sizeof(cmsgbuf); + + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] Accept[9] called %d\n", calls); + + int bytes_read = + recvmsg(child_fd[0], &child_msg, 0); + struct cmsghdr *cmsg = CMSG_FIRSTHDR(&child_msg); + if (cmsg == NULL || cmsg->cmsg_type != SCM_RIGHTS) + HPWS_ACCEPT_ERROR(203, "non-scm_rights message sent on accept child control line"); + memcpy(&buffer_fd, CMSG_DATA(cmsg), sizeof(buffer_fd)); + + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] Accept[10] called %d\n", calls); + + for (int i = 0; i < 4; ++i) + { + //fprintf(stderr, "scm passed buffer_fd[%d] = %d\n", i, buffer_fd[i]); + if (buffer_fd[i] < 0) + HPWS_ACCEPT_ERROR(203, "child accept scm_rights a passed buffer fd was negative"); + mapping[i] = + mmap(0, max_buffer_size_, PROT_READ | PROT_WRITE, MAP_SHARED, buffer_fd[i], 0); + if (mapping[i] == MAP_FAILED) + HPWS_ACCEPT_ERROR(204, "could not mmap scm_rights passed buffer fd"); + } + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] Accept[11] called %d\n", calls); + } + { + struct pollfd pfd; + int ret; + + for (int i = 0; i <= 1; ++i) + { + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] waiting for 'r' on child_fd[%d]=%d accept\n", i, child_fd[i]); + pfd.fd = child_fd[i]; + pfd.events = POLLERR | POLLHUP | POLLNVAL | POLLIN; + + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] Accept[12] called %d\n", calls); + + // now we wait for a 'r' ready message or for the socket/client to die + ret = poll(&pfd, 1, HPWS_LONG_TIMEOUT); // default= 1500 ms timeout + + if (!(pfd.revents & POLLIN)) + HPWS_ACCEPT_ERROR(5, "could not read from client_fd"); + + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] Accept[12a] called %d - ret %d\n", calls, ret); + char rbuf[2]; + bytes_read = recv(child_fd[i], rbuf, sizeof(rbuf), 0); + if (bytes_read < 1) + HPWS_ACCEPT_ERROR(2, "nil message sent by hpws on startup on accept"); + + if (rbuf[0] != 'r') + HPWS_ACCEPT_ERROR(3, "unexpected content in message sent by hpws client mode on startup"); + + if (rbuf[1] != '0' + i) + HPWS_ACCEPT_ERROR(4, "received wrong r message on control line fd"); + + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] 'r%c' received on child_fd[%d]=%d\n", rbuf[1], i, child_fd[i]); + } + + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] Accept[13] called %d\n", calls); + } + + return client{ + "", + buf, + child_fd[0], + child_fd[1], + max_buffer_size_, + (pid_t)pid, + buffer_fd, + mapping}; + } + + ~server() + { + if (!moved) + { + + // RH TODO ensure this pid terminates by following up with a SIGKILL + if (server_pid_ > 0) + { + kill(server_pid_, SIGTERM); + int status; + waitpid(server_pid_, &status, 0 /* should we use WNOHANG? */); + } + + ::close(master_control_fd_); + } + } + + static std::variant create( + std::string_view bin_path, + uint32_t max_buffer_size, + uint16_t port, + uint32_t max_con, + uint16_t max_con_per_ip, + std::string_view cert_path, + std::string_view key_path, + std::vector argv, //additional_arguments + std::function fork_child_init = NULL) + { +#define HPWS_SERVER_ERROR(code, msg) \ + { \ + error_code = code; \ + error_msg = msg; \ + goto server_error; \ + } + + int error_code = -1; + const char *error_msg = NULL; + int fd[2] = {-1, -1}; + pid_t pid = -1; + int count_args = 17 + argv.size(); + char const **argv_pass = NULL; + + if (socketpair(AF_UNIX, SOCK_SEQPACKET, 0, fd)) + HPWS_SERVER_ERROR(100, "could not create unix domain socket pair"); + + // construct the arguments + char shm_size[32]; + + if (snprintf(shm_size, 32, "%d", max_buffer_size) <= 0) + HPWS_SERVER_ERROR(90, "couldn't write shm size to string"); + + char port_str[6]; + if (snprintf(port_str, 6, "%d", port) <= 0) + HPWS_SERVER_ERROR(91, "couldn't write port to string"); + + char max_con_str[11]; + if (snprintf(max_con_str, 11, "%d", max_con) <= 0) + HPWS_SERVER_ERROR(92, "couldn't write max_con to string"); + + char max_con_per_ip_str[6]; + if (snprintf(max_con_per_ip_str, 6, "%d", max_con_per_ip) <= 0) + HPWS_SERVER_ERROR(93, "couldn't write max_con_per_ip to string"); + + argv_pass = + reinterpret_cast(alloca(sizeof(char *) * count_args)); + { + int upto = 0; + argv_pass[upto++] = bin_path.data(); + argv_pass[upto++] = "--server"; + argv_pass[upto++] = "--maxmsg"; + argv_pass[upto++] = shm_size; + argv_pass[upto++] = "--port"; + argv_pass[upto++] = port_str; + argv_pass[upto++] = "--cert"; + argv_pass[upto++] = cert_path.data(); + argv_pass[upto++] = "--key"; + argv_pass[upto++] = key_path.data(); + argv_pass[upto++] = "--cntlfd"; + argv_pass[upto++] = "3"; + argv_pass[upto++] = "--maxcon"; + argv_pass[upto++] = max_con_str; + argv_pass[upto++] = "--maxconip"; + argv_pass[upto++] = max_con_per_ip_str; + for (std::string_view &arg : argv) + argv_pass[upto++] = arg.data(); + argv_pass[upto] = NULL; + } + + pid = vfork(); + if (pid) + { + + // --- PARENT + + // Fds are set to -1, so when error occurred these fds won't get closed again. + ::close(fd[1]); + fd[1] = -1; + + int flags = fcntl(fd[0], F_GETFD, NULL); + if (flags < 0) + HPWS_SERVER_ERROR(101, "could not get flags from unix domain socket"); + + flags |= FD_CLOEXEC; + if (fcntl(fd[0], F_SETFD, flags)) + HPWS_SERVER_ERROR(102, "could notset flags for unix domain socket"); + + // we will set a timeout and wait for the initial startup message from hpws server mode + struct pollfd pfd; + int ret; + + pfd.fd = fd[0]; + pfd.events = POLLIN; + ret = poll(&pfd, 1, HPWS_LONG_TIMEOUT); // default= 1500 ms timeout + + // timeout or error + if (ret < 1) + HPWS_SERVER_ERROR(1, "timeout waiting for hpws startup message"); + + char buf[1024]; + int bytes_read = recv(fd[0], buf, sizeof(buf) - 1, 0); + if (bytes_read < 1) + { + int status; + // Wait and obtain exit status code of hpws. + if (waitpid(pid, &status, 0) > 0) + { + switch (WEXITSTATUS(status)) + { + case 70: + HPWS_SERVER_ERROR(31, "Could not create listen socket."); + + case 72: + HPWS_SERVER_ERROR(32, "Could not bind socket for listen."); + + case 74: + HPWS_SERVER_ERROR(33, "Listen() failed."); + + default: + break; + } + } + HPWS_SERVER_ERROR(2, "nil message sent by hpws on startup"); + } + + buf[bytes_read] = '\0'; + if (strncmp(buf, "startup", 7) != 0) + { + fprintf(stderr, "startup message: `%.*s`\n", bytes_read, buf); + HPWS_SERVER_ERROR(3, "unexpected content in message sent by hpws on startup"); + } + return server{ + pid, + fd[0], + max_buffer_size}; + } + else + { + + // --- CHILD + if (fork_child_init) + fork_child_init(); + + ::close(fd[0]); + + // dup fd[1] into fd 3 + dup2(fd[1], 3); + ::close(fd[1]); + + // we're assuming all fds above 3 will have close_exec flag + execv(bin_path.data(), (char *const *)argv_pass); + // we will send a nil message down the pipe to help the parent know somethings gone wrong + char nil[1]; + nil[0] = 0; + send(3, nil, 1, 0); + exit(1); // execl failure as child will always result in exit here + } + + server_error:; + + // NB: execution to here can only happen in parent process + // clean up any mess after error + if (pid > 0) + { + kill(pid, SIGKILL); /* RH TODO change this to SIGTERM and set a timeout? */ + int status; + waitpid(pid, &status, 0 /* should we use WNOHANG? */); + } + if (fd[0] > 0) + ::close(fd[0]); + if (fd[1] > 0) + ::close(fd[1]); + + return error{error_code, std::string{error_msg}}; + } + }; +} // namespace hpws + +#endif diff --git a/src/conf.cpp b/src/conf.cpp index 3807030..60b0559 100644 --- a/src/conf.cpp +++ b/src/conf.cpp @@ -28,17 +28,6 @@ namespace conf return 0; } - /** - * Cleanup any resources. - */ - void deinit() - { - if (init_success) - { - // Deinit here. - } - } - /** * Create config here. * @return 0 for success. -1 for failure. @@ -70,6 +59,7 @@ namespace conf sa_config cfg = {}; cfg.version = "0.0.1"; + cfg.server.ip_port = {}; cfg.log.max_file_count = 50; cfg.log.max_mbytes_per_file = 10; cfg.log.log_level = "inf"; @@ -87,25 +77,28 @@ namespace conf } /** - * Updates the context with directory paths based on provided base directory. + * Updates the context with directory paths based on provided executable path. * This is called after parsing SA command line arg in order to populate the ctx. - * @param basedir Path to base directory. + * @param exepath Path to executable. */ - void set_dir_paths(std::string basedir) + void set_dir_paths(std::string exepath) { - if (basedir.empty()) + if (exepath.empty()) { // This code branch will never execute the way main is currently coded, but it might change in future - std::cerr << "Base directory must be specified\n"; + std::cerr << "Executeble path must be specified\n"; exit(1); } - // Resolving the path through realpath will remove any trailing slash if present - // Set config directory to the parent of the exec binary. - basedir = dirname((char *)util::realpath(basedir).data()); - ctx.config_dir = basedir + "/cfg"; + // resolving the path through realpath will remove any trailing slash if present + exepath = util::realpath(exepath); + + // Take the parent directory path. + ctx.exe_dir = dirname(exepath.data()); + ctx.hpws_exe_path = ctx.exe_dir + "/" + "hpws"; + ctx.config_dir = ctx.exe_dir + "/cfg"; ctx.config_file = ctx.config_dir + "/sa.cfg"; - ctx.log_dir = basedir + "/log"; + ctx.log_dir = ctx.exe_dir + "/log"; } /** @@ -114,15 +107,19 @@ namespace conf */ int validate_dir_paths() { - const std::string paths[2] = { + const std::string paths[3] = { ctx.config_file, - ctx.log_dir}; + ctx.log_dir, + ctx.hpws_exe_path}; for (const std::string &path : paths) { if (!util::is_file_exists(path) && !util::is_dir_exists(path)) { - std::cerr << path << " does not exist.\n"; + if (path == ctx.hpws_exe_path) + std::cerr << path << " binary does not exist.\n"; + else + std::cerr << path << " does not exist.\n"; return -1; } } @@ -179,6 +176,31 @@ namespace conf std::string jpath; + // server + { + jpath = "server"; + + try + { + const jsoncons::ojson &server = d["server"]; + + cfg.server.ip_port.host_address = server["host"].as(); + cfg.server.ip_port.port = server["port"].as(); + + // Push the peer address and the port to peers set + if (cfg.server.ip_port.host_address.empty()) + { + std::cerr << "Configured server host_address is empty.\n"; + return -1; + } + } + catch (const std::exception &e) + { + print_missing_field_error(jpath, e); + return -1; + } + } + // log { jpath = "log"; @@ -217,6 +239,16 @@ namespace conf jsoncons::ojson d; d.insert_or_assign("version", cfg.version); + // Server configs. + { + jsoncons::ojson server_config; + + server_config.insert_or_assign("host", cfg.server.ip_port.host_address); + server_config.insert_or_assign("port", cfg.server.ip_port.port); + + d.insert_or_assign("server", server_config); + } + // Log configs. { jsoncons::ojson log_config; diff --git a/src/conf.hpp b/src/conf.hpp index 0ae6d62..ec7143a 100644 --- a/src/conf.hpp +++ b/src/conf.hpp @@ -14,6 +14,32 @@ namespace conf ERROR }; + struct host_ip_port + { + std::string host_address; + uint16_t port = 0; + + bool operator==(const host_ip_port &other) const + { + return host_address == other.host_address && port == other.port; + } + + bool operator!=(const host_ip_port &other) const + { + return !(host_address == other.host_address && port == other.port); + } + + bool operator<(const host_ip_port &other) const + { + return (host_address == other.host_address) ? port < other.port : host_address < other.host_address; + } + + const std::string to_string() const + { + return host_address + ":" + std::to_string(port); + } + }; + struct log_config { std::string log_level; // Log severity level (dbg, inf, wrn, wrr) @@ -23,19 +49,27 @@ namespace conf size_t max_file_count = 0; // Max no. of log files to keep. }; + struct server_config + { + host_ip_port ip_port; + }; + struct sa_config { std::string version; + server_config server; log_config log; }; struct sa_context { - std::string command; // The CLI command issued to launch Sashimono agent + std::string command; // The CLI command issued to launch Sashimono agent + std::string exe_dir; // Hot Pocket executable dir. + std::string hpws_exe_path; // hpws executable file path. - std::string config_dir; // Config dir full path. - std::string config_file; // Full path to the config file. - std::string log_dir; // Log directory full path. + std::string config_dir; // Config dir full path. + std::string config_file; // Full path to the config file. + std::string log_dir; // Log directory full path. }; // Global context struct exposed to the application. @@ -48,11 +82,9 @@ namespace conf int init(); - void deinit(); - int create(); - void set_dir_paths(std::string basedir); + void set_dir_paths(std::string exepath); int validate_dir_paths(); diff --git a/src/main.cpp b/src/main.cpp index d67aacf..077971f 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -5,6 +5,7 @@ #include "conf.hpp" #include "sqlite.hpp" #include "salog.hpp" +#include "comm/comm_handler.hpp" /** * Parses CLI args and extracts sashimono agent command and parameters given. @@ -14,13 +15,16 @@ */ int parse_cmd(int argc, char **argv) { - conf::ctx.command = argv[1]; - if (argc == 2 && //We get working dir as an arg anyway. So we need to check for ==2 args. - (conf::ctx.command == "new" || conf::ctx.command == "run" || conf::ctx.command == "version")) + if (argc > 1) { - // We populate the global contract ctx with the detected command. - conf::set_dir_paths(argv[0]); - return 0; + conf::ctx.command = argv[1]; + if (argc == 2 && //We get working dir as an arg anyway. So we need to check for ==2 args. + (conf::ctx.command == "new" || conf::ctx.command == "run" || conf::ctx.command == "version")) + { + // We populate the global contract ctx with the detected command. + conf::set_dir_paths(argv[0]); + return 0; + } } // If all extractions fail display help message. @@ -38,11 +42,70 @@ int parse_cmd(int argc, char **argv) */ void deinit() { - conf::deinit(); + comm::deinit(); +} + +void sig_exit_handler(int signum) +{ + LOG_WARNING << "Interrupt signal (" << signum << ") received."; + deinit(); + LOG_WARNING << "sagent exited due to signal."; + exit(signum); +} + +void segfault_handler(int signum) +{ + exit(SIGABRT); +} + +/** + * Global exception handler for std exceptions. + */ +void std_terminate() noexcept +{ + const std::exception_ptr exptr = std::current_exception(); + if (exptr != 0) + { + try + { + std::rethrow_exception(exptr); + } + catch (std::exception &ex) + { + LOG_ERROR << "std error: " << ex.what(); + } + catch (...) + { + LOG_ERROR << "std error: Terminated due to unknown exception"; + } + } + else + { + LOG_ERROR << "std error: Terminated due to unknown reason"; + } + + exit(1); } int main(int argc, char **argv) { + // Register exception and segfault handlers. + std::set_terminate(&std_terminate); + signal(SIGSEGV, &segfault_handler); + signal(SIGABRT, &segfault_handler); + + // Become a sub-reaper so we can gracefully reap hpws child processes via hpws.hpp. + // (Otherwise they will get reaped by OS init process and we'll end up with race conditions with gracefull kills) + prctl(PR_SET_CHILD_SUBREAPER, 1); + + // Disable SIGPIPE to avoid crashing on broken pipe IO. + { + sigset_t mask; + sigemptyset(&mask); + sigaddset(&mask, SIGPIPE); + pthread_sigmask(SIG_BLOCK, &mask, NULL); + } + // Extract the CLI args // This call will populate conf::ctx if (parse_cmd(argc, argv) != 0) @@ -65,40 +128,26 @@ int main(int argc, char **argv) { LOG_INFO << "Sashimono agent started. Version : " << conf::cfg.version << " Log level : " << conf::cfg.log.log_level; - // Run the program. - - sqlite3 *db = NULL; - const char *path = "db.sqlite"; - - if (sqlite::open_db(path, &db, true) == -1) + if (comm::init() == -1) { - LOG_ERROR << "Error opening database"; + deinit(); return -1; } - LOG_INFO << "Database " << path << " opened successfully"; - const std::vector column_info{ - sqlite::table_column_info("VERSION", sqlite::COLUMN_DATA_TYPE::TEXT)}; + // After initializing primary subsystems, register the exit handler. + signal(SIGINT, &sig_exit_handler); + signal(SIGTERM, &sig_exit_handler); - if (create_table(db, "SA_VERSION", column_info) == -1) - return -1; + // Waiting for the websocket sessions. + comm::wait(); - if (sqlite::insert_row(db, "SA_VERSION", "VERSION", "\"0.0.0\"") == -1) - return -1; + deinit(); - if (sqlite::close_db(&db) == -1) - { - LOG_ERROR << "Error closing database"; - return -1; - } + LOG_INFO << "sashimono agent exited normally."; } else if (conf::ctx.command == "version") - // Print the version LOG_INFO << "Sashimono Agent " << conf::cfg.version; - deinit(); + return 0; } - - LOG_INFO << "sashimono agent exited normally."; - return 0; -} +} \ No newline at end of file diff --git a/src/pchheader.hpp b/src/pchheader.hpp index df888d0..b877602 100644 --- a/src/pchheader.hpp +++ b/src/pchheader.hpp @@ -2,6 +2,8 @@ #define _SA_PCHHEADER_ #include +#include +#include #include #include #include @@ -9,10 +11,12 @@ #include #include #include +#include #include #include #include #include +#include #include #include diff --git a/src/util/util.cpp b/src/util/util.cpp index f161724..cf78346 100644 --- a/src/util/util.cpp +++ b/src/util/util.cpp @@ -97,4 +97,35 @@ namespace util return buffer.data(); } + /** + * Clears signal mask and signal handlers from the caller. + * Called by other processes forked from sagent threads so they get detatched from + * the sagent signal setup. + */ + void fork_detach() + { + // Restore signal handlers to defaults. + signal(SIGINT, SIG_DFL); + signal(SIGSEGV, SIG_DFL); + signal(SIGABRT, SIG_DFL); + + // Remove any signal masks applied by sagent. + sigset_t mask; + sigemptyset(&mask); + pthread_sigmask(SIG_SETMASK, &mask, NULL); + + // Set process group id (so the terminal doesn't send kill signals to forked children). + setpgrp(); + } + + // Applies signal mask to the calling thread. + void mask_signal() + { + sigset_t mask; + sigemptyset(&mask); + sigaddset(&mask, SIGINT); + sigaddset(&mask, SIGPIPE); + pthread_sigmask(SIG_BLOCK, &mask, NULL); + } + } // namespace util diff --git a/src/util/util.hpp b/src/util/util.hpp index a885ec7..6f971af 100644 --- a/src/util/util.hpp +++ b/src/util/util.hpp @@ -18,6 +18,10 @@ namespace util const std::string realpath(std::string_view path); + void fork_detach(); + + void mask_signal(); + } // namespace util #endif