diff --git a/src/comm/comm_server.hpp b/src/comm/comm_server.hpp index b9f5b970..c4c195ee 100644 --- a/src/comm/comm_server.hpp +++ b/src/comm/comm_server.hpp @@ -24,6 +24,7 @@ namespace comm const uint64_t max_in_connections; const uint64_t max_in_connections_per_host; bool use_priority_queues = false; // Whether to activate inbound message high vs low priority-based processing. + const bool visa_required = false; // Whether clients need to send visa requests to connect. bool is_shutting_down = false; std::list sessions; std::list new_sessions; // Sessions that haven't been initialized properly which are yet to be merge to "sessions" list. @@ -192,7 +193,7 @@ namespace comm // loop will take care of initialize the new sessions. This is because inherited classes (eg. peer_comm_server) // need a way to safely inject new sessions from another thread. std::scoped_lock lock(new_sessions_mutex); - new_sessions.emplace_back(this->violation_tracker, host_address, std::move(client), client.is_ipv4, true, metric_thresholds); + new_sessions.emplace_back(this->violation_tracker, host_address, std::move(client), client.is_ipv4, true, metric_thresholds, !visa_required); return 1; } @@ -266,7 +267,9 @@ namespace comm conf::ctx.tls_cert_file, conf::ctx.tls_key_file, {}, - util::fork_detach); + util::fork_detach, + false, + visa_required ? conf::cfg.contract.id : std::optional{}); if (std::holds_alternative(result)) { @@ -284,12 +287,13 @@ namespace comm corebill::tracker violation_tracker; comm_server(std::string_view name, const uint16_t port, const uint64_t (&metric_thresholds)[5], const uint64_t max_msg_size, - const uint64_t max_in_connections, const uint64_t max_in_connections_per_host, const bool use_priority_queues) + const uint64_t max_in_connections, const uint64_t max_in_connections_per_host, const bool use_priority_queues, const bool visa_required) : metric_thresholds(metric_thresholds), max_msg_size(max_msg_size > 0 ? max_msg_size : DEFAULT_MAX_MSG_SIZE), max_in_connections(max_in_connections > 0 ? max_in_connections : DEFAULT_MAX_CONNECTIONS), max_in_connections_per_host(max_in_connections_per_host > 0 ? max_in_connections_per_host : DEFAULT_MAX_CONNECTIONS), use_priority_queues(use_priority_queues), + visa_required(visa_required), name(name), listen_port(port) { diff --git a/src/comm/comm_session.cpp b/src/comm/comm_session.cpp index 8eee1f37..cd6738b4 100644 --- a/src/comm/comm_session.cpp +++ b/src/comm/comm_session.cpp @@ -13,8 +13,9 @@ namespace comm constexpr uint16_t MAX_IN_MSG_QUEUE_SIZE = 255; // Maximum in message queue size, The size passed is rounded to next number in binary sequence 1(1),11(3),111(7),1111(15),11111(31).... comm_session::comm_session(corebill::tracker &violation_tracker, - std::string_view host_address, hpws::client &&hpws_client, const bool is_ipv4, const bool is_inbound, const uint64_t (&metric_thresholds)[5]) + std::string_view host_address, hpws::client &&hpws_client, const bool is_ipv4, const bool is_inbound, const uint64_t (&metric_thresholds)[5], const bool corebill_enabled) : violation_tracker(violation_tracker), + corebill_enabled(corebill_enabled), hpws_client(std::move(hpws_client)), in_msg_queue1(MAX_IN_MSG_QUEUE_SIZE), in_msg_queue2(MAX_IN_MSG_QUEUE_SIZE), @@ -168,7 +169,7 @@ namespace comm * @param message Message to be added to the outbound queue. * @param priority If 1 adds to high priority queue. Else adds to low priority queue. * @return 0 on successful addition and -1 if the session is already closed. - */ + */ int comm_session::send(const std::vector &message, const uint16_t priority) { std::string_view sv(reinterpret_cast(message.data()), message.size()); @@ -180,7 +181,7 @@ namespace comm * @param message Message to be added to the outbound queue. * @param priority If 1 adds to high priority queue. Else adds to low priority queue. * @return 0 on successful addition and -1 if the session is already closed. - */ + */ int comm_session::send(std::string_view message, const uint16_t priority) { if (state == SESSION_STATE::CLOSED) @@ -202,7 +203,7 @@ namespace comm * This function constructs and sends the message to the target from the given message. * @param message Message to be sent via the pipe. * @return 0 on successful message sent and -1 on error. - */ + */ int comm_session::process_outbound_message(std::string_view message) { if (state == SESSION_STATE::CLOSED || !hpws_client) @@ -219,7 +220,7 @@ namespace comm /** * Process message sending in the queue in the outbound_queue_thread. - */ + */ void comm_session::process_outbound_msg_queue() { // Appling a signal mask to prevent receiving control signals from linux kernel. @@ -263,7 +264,7 @@ namespace comm state = SESSION_STATE::MUST_CLOSE; - if (reason != CLOSE_VIOLATION::VIOLATION_NONE) + if (corebill_enabled && reason != CLOSE_VIOLATION::VIOLATION_NONE) violation_tracker.report_violation(host_address, is_ipv4, std::to_string(reason)); } @@ -310,7 +311,7 @@ namespace comm /** * Set thresholds to the socket session - */ + */ void comm_session::set_threshold(const SESSION_THRESHOLDS threshold_type, const uint64_t threshold_limit, const uint32_t intervalms) { session_threshold &t = thresholds[threshold_type]; @@ -320,9 +321,9 @@ namespace comm } /* - * Increment the provided thresholds counter value with the provided amount and validate it against the - * configured threshold limit. - */ + * Increment the provided thresholds counter value with the provided amount and validate it against the + * configured threshold limit. + */ void comm_session::increment_metric(const SESSION_THRESHOLDS threshold_type, const uint64_t amount) { session_threshold &t = thresholds[threshold_type]; @@ -360,7 +361,7 @@ namespace comm /** * Check whether the connection expires according to last activity time rules and then mark for closure. - */ + */ void comm_session::check_last_activity_rules() { const uint32_t timeout = (challenge_status == CHALLENGE_STATUS::CHALLENGE_VERIFIED ? thresholds[SESSION_THRESHOLDS::IDLE_CONNECTION_TIMEOUT].threshold_limit : UNVERIFIED_INACTIVE_TIMEOUT); @@ -378,7 +379,7 @@ namespace comm /** * Mark the connection as a verified connection. - */ + */ void comm_session::mark_as_verified() { challenge_status = CHALLENGE_STATUS::CHALLENGE_VERIFIED; diff --git a/src/comm/comm_session.hpp b/src/comm/comm_session.hpp index 84aa1fa2..07bd2289 100644 --- a/src/comm/comm_session.hpp +++ b/src/comm/comm_session.hpp @@ -41,6 +41,7 @@ namespace comm { private: corebill::tracker &violation_tracker; + const bool corebill_enabled; // Wether corebill enabled for the session. std::optional hpws_client; std::vector thresholds; // track down various communication thresholds @@ -72,7 +73,7 @@ namespace comm uint64_t last_activity_timestamp; // Keep track of the last activity timestamp in milliseconds. comm_session(corebill::tracker &violation_tracker, - std::string_view host_address, hpws::client &&hpws_client, const bool is_ipv4, const bool is_inbound, const uint64_t (&metric_thresholds)[5]); + std::string_view host_address, hpws::client &&hpws_client, const bool is_ipv4, const bool is_inbound, const uint64_t (&metric_thresholds)[5], const bool corebill_enabled); int init(); int process_next_inbound_message(const uint16_t priority); int send(const std::vector &message, const uint16_t priority = 2); diff --git a/src/comm/hpws.hpp b/src/comm/hpws.hpp index bc73eaab..3a4247d8 100644 --- a/src/comm/hpws.hpp +++ b/src/comm/hpws.hpp @@ -48,6 +48,7 @@ namespace hpws #define HPWS_SMALL_TIMEOUT 10 // used when waiting for server process to spawn #define HPWS_LONG_TIMEOUT 1500 // This timeout has to account the possible delays in communication via internet. +#define HPWS_VISA_TIMEOUT 5000 // This timeout has to account the possible delays in visa pow calculation. typedef union { @@ -322,7 +323,10 @@ namespace hpws uint16_t port, std::string_view get, std::vector argv, - std::function fork_child_init = NULL) + std::function fork_child_init = NULL, + bool is_ipv4 = false, + std::optional visa_token = {}, + std::function parent_terminated = NULL) { #define HPWS_CONNECT_ERROR(code, msg) \ @@ -338,7 +342,7 @@ namespace hpws int buffer_fd[4] = {-1, -1, -1, -1}; void *mapping[4] = {NULL, NULL, NULL, NULL}; int pid = -1; - int count_args = 14 + argv.size(); + int count_args = 14 + argv.size() + (visa_token.has_value() ? 2 : 0) + (is_ipv4 ? 1 : 0); char const **argv_pass = NULL; if (socketpair(AF_UNIX, SOCK_SEQPACKET, 0, fd)) @@ -381,6 +385,13 @@ namespace hpws argv_pass[upto++] = cfd2; argv_pass[upto++] = "--get"; argv_pass[upto++] = get.data(); + if (is_ipv4) + argv_pass[upto++] = "--ipv4"; + if (visa_token.has_value()) + { + argv_pass[upto++] = "--visatoken"; + argv_pass[upto++] = visa_token.value().data(); + } for (std::string_view &arg : argv) argv_pass[upto++] = arg.data(); argv_pass[upto] = NULL; @@ -415,11 +426,27 @@ namespace hpws // we will set a timeout and wait for the initial startup message from hpws client mode struct pollfd pfd; - int ret; + int ret = 0; pfd.fd = child_fd[0]; // we receive setup events on control line 0 (hpws->hpcore) pfd.events = POLLIN; - ret = poll(&pfd, 1, HPWS_LONG_TIMEOUT); // default= 1500 ms timeout + if (!visa_token.has_value()) + ret = poll(&pfd, 1, HPWS_LONG_TIMEOUT); // default= 1500 ms timeout + else + { + // If the signals are blocked in the caller we handle termination from outside. + // We wait in a loop because visa timeout could be long otherwise termination hangs. + uint32_t timer = 0; + while (ret == 0 && timer < HPWS_VISA_TIMEOUT) + { + // Break if connection is terminated from the server. + if (parent_terminated && parent_terminated()) + break; + + ret = poll(&pfd, 1, 1000); + timer += 1000; + } + } // timeout or error if (ret < 1) @@ -454,7 +481,7 @@ namespace hpws memcpy(&buffer_fd, CMSG_DATA(cmsg), sizeof(buffer_fd)); for (int i = 0; i < 4; ++i) { - //fprintf(stderr, "scm passed buffer_fd[%d] = %d\n", i, buffer_fd[i]); + // fprintf(stderr, "scm passed buffer_fd[%d] = %d\n", i, buffer_fd[i]); if (buffer_fd[i] < 0) HPWS_CONNECT_ERROR(203, "child accept scm_rights a passed buffer fd was negative"); mapping[i] = @@ -748,6 +775,10 @@ namespace hpws if (ret < 1) HPWS_ACCEPT_ERROR(202, "timeout waiting for hpws accept child message"); + // check whether write end in forcefully closed + if (pfd.revents & POLLHUP) + HPWS_ACCEPT_ERROR(205, "child connection closed"); + // first thing we'll receive is the pid of the client if (recv(child_fd[0], (unsigned char *)(&pid), sizeof(pid), 0) < (ssize_t)sizeof(pid)) HPWS_ACCEPT_ERROR(212, "did not receive expected 4 byte pid of child process on accept"); @@ -788,7 +819,7 @@ namespace hpws for (int i = 0; i < 4; ++i) { - //fprintf(stderr, "scm passed buffer_fd[%d] = %d\n", i, buffer_fd[i]); + // fprintf(stderr, "scm passed buffer_fd[%d] = %d\n", i, buffer_fd[i]); if (buffer_fd[i] < 0) HPWS_ACCEPT_ERROR(203, "child accept scm_rights a passed buffer fd was negative"); mapping[i] = @@ -876,8 +907,10 @@ namespace hpws uint16_t max_con_per_ip, std::string_view cert_path, std::string_view key_path, - std::vector argv, //additional_arguments - std::function fork_child_init = NULL) + std::vector argv, // additional_arguments + std::function fork_child_init = NULL, + bool is_ipv6 = false, + std::optional visa_token = {}) { #define HPWS_SERVER_ERROR(code, msg) \ { \ @@ -890,7 +923,7 @@ namespace hpws const char *error_msg = NULL; int fd[2] = {-1, -1}; pid_t pid = -1; - int count_args = 17 + argv.size(); + int count_args = 16 + argv.size() + (is_ipv6 ? 1 : 0) + (visa_token.has_value() ? 2 : 0); char const **argv_pass = NULL; if (socketpair(AF_UNIX, SOCK_SEQPACKET, 0, fd)) @@ -934,6 +967,13 @@ namespace hpws argv_pass[upto++] = max_con_str; argv_pass[upto++] = "--maxconip"; argv_pass[upto++] = max_con_per_ip_str; + if (is_ipv6) + argv_pass[upto++] = "--ipv6"; + if (visa_token.has_value()) + { + argv_pass[upto++] = "--visatoken"; + argv_pass[upto++] = visa_token.value().data(); + } for (std::string_view &arg : argv) argv_pass[upto++] = arg.data(); argv_pass[upto] = NULL; diff --git a/src/p2p/peer_comm_server.cpp b/src/p2p/peer_comm_server.cpp index e8c92996..ea24b85f 100644 --- a/src/p2p/peer_comm_server.cpp +++ b/src/p2p/peer_comm_server.cpp @@ -18,7 +18,7 @@ namespace p2p peer_comm_server::peer_comm_server(const uint16_t port, const uint64_t (&metric_thresholds)[5], const uint64_t max_msg_size, const uint64_t max_in_connections, const uint64_t max_in_connections_per_host, const std::vector &req_known_remotes) - : comm::comm_server("Peer", port, metric_thresholds, max_msg_size, max_in_connections, max_in_connections_per_host, true), + : comm::comm_server("Peer", port, metric_thresholds, max_msg_size, max_in_connections, max_in_connections_per_host, true, true), req_known_remotes(req_known_remotes) // Copy over known peers into internal collection. { } @@ -171,7 +171,8 @@ namespace p2p const uint16_t port = peer.ip_port.port; LOG_DEBUG << "Trying to connect " << host << ":" << std::to_string(port); - std::variant client_result = hpws::client::connect(conf::ctx.hpws_exe_path, max_msg_size, host, port, "/", {}, util::fork_detach); + std::variant client_result = hpws::client::connect(conf::ctx.hpws_exe_path, max_msg_size, host, port, "/", {}, util::fork_detach, true, conf::cfg.contract.id, [&]() -> bool + { return is_shutting_down; }); if (std::holds_alternative(client_result)) { @@ -198,7 +199,7 @@ namespace p2p else { const std::string &host_address = std::get(host_result); - p2p::peer_comm_session session(this->violation_tracker, host_address, std::move(client), client.is_ipv4, false, metric_thresholds); + p2p::peer_comm_session session(this->violation_tracker, host_address, std::move(client), client.is_ipv4, false, metric_thresholds, !visa_required); // Skip if this peer is banned due to corebill violations. if (violation_tracker.is_banned(host_address)) diff --git a/src/usr/usr.cpp b/src/usr/usr.cpp index 91bd384a..061f48e8 100644 --- a/src/usr/usr.cpp +++ b/src/usr/usr.cpp @@ -80,7 +80,7 @@ namespace usr int start_listening() { ctx.server.emplace("User", conf::cfg.user.port, metric_thresholds, conf::cfg.user.max_bytes_per_msg, - conf::cfg.user.max_connections, conf::cfg.user.max_in_connections_per_host, false); + conf::cfg.user.max_connections, conf::cfg.user.max_in_connections_per_host, false, false); if (ctx.server->start() == -1) return -1; diff --git a/test/bin/hpws b/test/bin/hpws index 91634d23..aa5e24d5 100755 Binary files a/test/bin/hpws and b/test/bin/hpws differ