diff --git a/CMakeLists.txt b/CMakeLists.txt index 32acdda5..9c85d1cd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -39,6 +39,11 @@ target_link_libraries(hpsupport ssl ) +add_library(hpbill + src/bill/corebill.cpp +) +target_link_libraries(hpbill hpsupport) + add_library(hpsock src/sock/socket_client.cpp src/sock/socket_server.cpp @@ -46,7 +51,7 @@ add_library(hpsock src/sock/socket_session.cpp src/sock/socket_session_lambda.cpp ) -target_link_libraries(hpsock hpsupport) +target_link_libraries(hpsock hpsupport hpbill) add_library(hpschema src/fbschema/common_helpers.cpp @@ -79,6 +84,7 @@ add_executable(hpcore ) target_link_libraries(hpcore hpsupport + hpbill hpschema hpsock hpp2p diff --git a/src/bill/corebill.cpp b/src/bill/corebill.cpp new file mode 100644 index 00000000..9847c5bc --- /dev/null +++ b/src/bill/corebill.cpp @@ -0,0 +1,87 @@ +#include "../pchheader.hpp" +#include "../util.hpp" +#include "../hplog.hpp" +#include "corebill.h" + +namespace corebill +{ + +// How many violations can occur for a host before being escalated. +static const uint32_t VIOLATION_THRESHOLD = 10; + +// Violation cooldown interval. +static const uint32_t VIOLATION_REFRESH_INTERVAL = 600 * 1000; // 10 minutes + +// Keeps track of violation count against offending hosts. +std::unordered_map violation_counter; + +// Keeps the graylisted hosts. +util::ttl_set graylist; + +// Keeps the whitelisted hosts who would be ignored in all violation tracking. +std::unordered_set whitelist; + +/** + * Report a violation. Violation means a force disconnection of a socket due to some threshold exceeding. + */ +void report_violation(const std::string host) +{ + if (whitelist.find(host) != whitelist.end()) // Is in whitelist + { + LOG_DBG << host << " is whitelisted. Ignoring the violation."; + return; + } + + violation_stat &stat = violation_counter[host]; + + const uint64_t time_now = util::get_epoch_milliseconds(); + + stat.counter++; + + if (stat.timestamp == 0) + { + // Reset counter timestamp. + stat.timestamp = time_now; + } + else + { + // Check whether we have exceeded the threshold within the monitering interval. + const uint64_t elapsed_time = time_now - stat.timestamp; + if (elapsed_time <= VIOLATION_REFRESH_INTERVAL && stat.counter > VIOLATION_THRESHOLD) + { + // IP exceeded violation threshold. + + stat.timestamp = 0; + stat.counter = 0; + + graylist.emplace(host, VIOLATION_REFRESH_INTERVAL); + LOG_WARN << host << " placed on graylist."; + } + else if (elapsed_time > VIOLATION_REFRESH_INTERVAL) + { + // Start the counter fresh. + stat.timestamp = time_now; + stat.counter = 1; + } + } +} + +void add_to_whitelist(const std::string host) +{ + // Add to whitelist and remove from all other offender lists. + whitelist.emplace(host); + graylist.erase(host); + violation_counter.erase(host); +} + +void remove_from_whitelist(const std::string host) +{ + whitelist.erase(host); +} + +bool is_banned(const std::string &host) +{ + return graylist.exists(host); +} + +} // namespace corebill \ No newline at end of file diff --git a/src/bill/corebill.h b/src/bill/corebill.h new file mode 100644 index 00000000..07f37a2c --- /dev/null +++ b/src/bill/corebill.h @@ -0,0 +1,24 @@ +#ifndef _HP_COREBILL_ +#define _HP_COREBILL_ + +#include "../pchheader.hpp" + +namespace corebill +{ + +/** + * Keeps the violation counter and the timestamp of the monitoring window. + */ +struct violation_stat +{ + uint32_t counter; + uint64_t timestamp; +}; + +void report_violation(const std::string host); +void add_to_whitelist(const std::string host); +bool is_banned(const std::string &host); + +} // namespace corebill + +#endif \ No newline at end of file diff --git a/src/sock/socket_server.cpp b/src/sock/socket_server.cpp index c63242cd..a87fabb5 100644 --- a/src/sock/socket_server.cpp +++ b/src/sock/socket_server.cpp @@ -93,7 +93,7 @@ void socket_server::on_accept(error_code ec, tcp::socket socket) // Launch a new session for this connection std::make_shared>(std::move(ws), sess_handler) - ->run(std::move(port), std::move(address), true, sess_opts); + ->run(std::move(address), std::move(port), true, sess_opts); } // Accept another connection diff --git a/src/sock/socket_session.cpp b/src/sock/socket_session.cpp index 6ccac347..759144d6 100644 --- a/src/sock/socket_session.cpp +++ b/src/sock/socket_session.cpp @@ -1,3 +1,4 @@ +#include "../bill/corebill.h" #include "socket_session.hpp" #include "socket_message.hpp" #include "socket_session_handler.hpp" @@ -64,11 +65,14 @@ void socket_session::increment_metric(const SESSION_THRESHOLDS threshold_type const uint64_t elapsed_time = time_now - t.timestamp; if (elapsed_time <= t.intervalms && t.counter_value > t.threshold_limit) { + this->close(); + t.timestamp = 0; t.counter_value = 0; LOG_INFO << "Session " << this->uniqueid << " threshold exceeded. (type:" << threshold_type << " limit:" << t.threshold_limit << ")"; - this->close(); + corebill::report_violation(this->address); + } else if (elapsed_time > t.intervalms) { @@ -165,6 +169,12 @@ void socket_session::on_accept(const error_code ec) if (ec) return fail(ec, "accept"); + if (corebill::is_banned(this->address)) + { + LOG_DBG << "Dropping connection for banned host " << this->address; + this->close(); + } + sess_handler.on_connect(this); // Read a message diff --git a/src/usr/user_session_handler.cpp b/src/usr/user_session_handler.cpp index e2d60d05..47915f30 100644 --- a/src/usr/user_session_handler.cpp +++ b/src/usr/user_session_handler.cpp @@ -3,6 +3,7 @@ #include "../jsonschema/usrmsg_helpers.hpp" #include "../sock/socket_session.hpp" #include "../sock/socket_message.hpp" +#include "../bill/corebill.h" #include "usr.hpp" #include "user_session_handler.hpp" @@ -17,7 +18,7 @@ namespace usr */ void user_session_handler::on_connect(sock::socket_session *session) { - LOG_INFO << "User client connected " << session->address << ":" << session->port; + LOG_DBG << "User client connected " << session->address << ":" << session->port; // As soon as a user connects, we issue them a challenge message. We remember the // challenge we issued and later verifies the user's response with it. @@ -71,7 +72,8 @@ void user_session_handler::on_message( // If for any reason we reach this point, we should drop the connection because none of the // valid cases match. session->close(); - LOG_INFO << "Dropped the user connection " << session->address << ":" << session->port; + LOG_DBG << "Dropped the user connection " << session->uniqueid; + corebill::report_violation(session->address); } /** @@ -89,7 +91,7 @@ void user_session_handler::on_close(sock::socket_session else if (session->flags[sock::SESSION_FLAG::USER_AUTHED]) remove_user(session->uniqueid); - LOG_INFO << "User disconnected " << session->uniqueid; + LOG_DBG << "User disconnected " << session->uniqueid; } } // namespace usr \ No newline at end of file diff --git a/src/usr/usr.cpp b/src/usr/usr.cpp index c0d47ff2..9616916a 100644 --- a/src/usr/usr.cpp +++ b/src/usr/usr.cpp @@ -87,18 +87,18 @@ int verify_challenge(std::string_view message, sock::socket_sessionuniqueid); // Remove the stored challenge - LOG_INFO << "User connection " << session->uniqueid << " authenticated. Public key " + LOG_DBG << "User connection " << session->uniqueid << " authenticated. Public key " << userpubkeyhex; return 0; } else { - LOG_INFO << "Duplicate user public key " << session->uniqueid; + LOG_DBG << "Duplicate user public key " << session->uniqueid; } } else { - LOG_INFO << "Challenge verification failed " << session->uniqueid; + LOG_DBG << "Challenge verification failed " << session->uniqueid; } return -1; diff --git a/src/util.cpp b/src/util.cpp index 2d8c4bbe..50dffdc3 100644 --- a/src/util.cpp +++ b/src/util.cpp @@ -40,6 +40,44 @@ bool rollover_hashset::try_emplace(const std::string hash) return false; // Hash already exists. } +// ttl_set class methods. + +/** + * If key does not exist, inserts it with the specified ttl. If key exists, + * renews the expiration time to match the time-to-live from now onwards. + * @param key Object to insert. + * @param ttl Time to live in milliseonds. + */ +void ttl_set::emplace(const std::string key, uint64_t ttl_milli) +{ + ttlmap[key] = util::get_epoch_milliseconds() + ttl_milli; +} + +void ttl_set::erase(const std::string &key) +{ + const auto itr = ttlmap.find(key); + if (itr != ttlmap.end()) + ttlmap.erase(itr); +} + +/** + * Returns true of the key exists and not expired. Returns false if key does not exist + * or has expired. + */ +bool ttl_set::exists(const std::string &key) +{ + const auto itr = ttlmap.find(key); + if (itr == ttlmap.end()) // Not found + return false; + + // Check whether we are passed the expiration time (itr->second is the expiration time) + const bool expired = util::get_epoch_milliseconds() > itr->second; + if (expired) + ttlmap.erase(itr); + + return !expired; +} + /** * Encodes provided bytes to hex string. * @@ -92,8 +130,8 @@ int hex2bin(unsigned char *decodedbuf, const size_t decodedbuf_len, std::string_ int64_t get_epoch_milliseconds() { return std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(); + std::chrono::system_clock::now().time_since_epoch()) + .count(); } /** diff --git a/src/util.hpp b/src/util.hpp index 102b6223..bc72b26e 100644 --- a/src/util.hpp +++ b/src/util.hpp @@ -17,11 +17,11 @@ static const char *HP_VERSION = "0.1"; static const char *MIN_CONTRACT_VERSION = "0.1"; // Current version of the peer message protocol. -static const int PEERMSG_VERSION = 1; +static const uint8_t PEERMSG_VERSION = 1; // Minimum compatible peer message version (this will be used to accept/reject incoming peer connections) // (Keeping this as int for effcient msg payload and comparison) -static const int MIN_PEERMSG_VERSION = 1; +static const uint8_t MIN_PEERMSG_VERSION = 1; /** * FIFO hash set with a max size. @@ -43,6 +43,21 @@ public: bool try_emplace(const std::string hash); }; +/** + * A string set with expiration for elements. + */ +class ttl_set +{ +private: + // Keeps short-lived items with their absolute expiration time. + std::unordered_map ttlmap; + +public: + void emplace(const std::string key, uint64_t ttl_milli); + void erase(const std::string &key); + bool exists(const std::string &key); +}; + int bin2hex(std::string &encoded_string, const unsigned char *bin, const size_t bin_len); int hex2bin(unsigned char *decoded, const size_t decoded_len, std::string_view hex_str);