diff --git a/src/comm/comm_server.hpp b/src/comm/comm_server.hpp index 38297bf7..8036b796 100644 --- a/src/comm/comm_server.hpp +++ b/src/comm/comm_server.hpp @@ -84,6 +84,9 @@ namespace comm // Cleanup any sessions that needs closure. for (auto itr = sessions.begin(); itr != sessions.end();) { + // Check whether the connection expired according to last activity time rules. + itr->check_last_activity_rules(); + if (itr->state == SESSION_STATE::MUST_CLOSE) itr->close(true); diff --git a/src/comm/comm_session.cpp b/src/comm/comm_session.cpp index af6b78f6..e9e97e62 100644 --- a/src/comm/comm_session.cpp +++ b/src/comm/comm_session.cpp @@ -9,6 +9,8 @@ namespace comm { constexpr uint32_t INTERVALMS = 60000; + constexpr uint16_t INACTIVE_TIMEOUT = 120; // Time threshold for verified inactive connections in seconds. + constexpr uint16_t UNVERIFIED_INACTIVE_TIMEOUT = 5; // Time threshold for unverified inactive connections in seconds. comm_session::comm_session( std::string_view host_address, hpws::client &&hpws_client, const bool is_inbound, const uint64_t (&metric_thresholds)[4]) @@ -38,6 +40,7 @@ namespace comm reader_thread = std::thread(&comm_session::reader_loop, this); writer_thread = std::thread(&comm_session::process_outbound_msg_queue, this); state = SESSION_STATE::ACTIVE; + last_activity_timestamp = util::get_epoch_milliseconds(); } } @@ -59,6 +62,9 @@ namespace comm } else { + // Update last activity timestamp since this session received a message. + last_activity_timestamp = util::get_epoch_milliseconds(); + // Enqueue the message for processing. std::string_view data = std::get(read_result); std::vector msg(data.size()); @@ -124,6 +130,9 @@ namespace comm if (state == SESSION_STATE::CLOSED) return -1; + // Updating last activity timestamp since this session is sending a message. + last_activity_timestamp = util::get_epoch_milliseconds(); + // Passing the ownership of message to the queue. out_msg_queue.enqueue(std::string(message)); return 0; @@ -274,6 +283,20 @@ namespace comm } } + /** + * Check whether the connection expires according to last activity time rules and then mark for closure. + */ + void comm_session::check_last_activity_rules() + { + const uint16_t timeout_seconds = (challenge_status == CHALLENGE_STATUS::CHALLENGE_VERIFIED ? INACTIVE_TIMEOUT : UNVERIFIED_INACTIVE_TIMEOUT); + + if (util::get_epoch_milliseconds() - last_activity_timestamp >= (timeout_seconds * 1000)) + { + LOG_DEBUG << "Closing " << display_name() << " connection due to inactivity."; + mark_for_closure(); + } + } + void comm_session::handle_connect() { } diff --git a/src/comm/comm_session.hpp b/src/comm/comm_session.hpp index d63d5676..7a948fe5 100644 --- a/src/comm/comm_session.hpp +++ b/src/comm/comm_session.hpp @@ -52,6 +52,7 @@ namespace comm std::string issued_challenge; SESSION_STATE state = SESSION_STATE::NONE; CHALLENGE_STATUS challenge_status = CHALLENGE_STATUS::NOT_ISSUED; + uint64_t last_activity_timestamp; // Keep track of the last activity timestamp in milliseconds. comm_session( std::string_view host_address, hpws::client &&hpws_client, const bool is_inbound, const uint64_t (&metric_thresholds)[4]); @@ -61,6 +62,7 @@ namespace comm int send(std::string_view message); int process_outbound_message(std::string_view message); void process_outbound_msg_queue(); + void check_last_activity_rules(); void mark_for_closure(); void close(const bool invoke_handler = true); virtual const std::string display_name();