diff --git a/src/cluster/ClusterCommunicationService.cpp b/src/cluster/ClusterCommunicationService.cpp
index e2b87ba9..f714549c 100644
--- a/src/cluster/ClusterCommunicationService.cpp
+++ b/src/cluster/ClusterCommunicationService.cpp
@@ -21,9 +21,13 @@
 #include "cluster/ClioNode.hpp"
 #include "data/BackendInterface.hpp"
+#include "util/Assert.hpp"
 #include "util/Spawn.hpp"
 #include "util/log/Logger.hpp"
 
+#include <boost/asio/bind_cancellation_slot.hpp>
+#include <boost/asio/cancellation_type.hpp>
+#include <boost/asio/error.hpp>
 #include <boost/asio/spawn.hpp>
 #include <boost/asio/steady_timer.hpp>
 #include <boost/uuid/uuid_generators.hpp>
@@ -36,11 +40,16 @@
 
 #include <chrono>
 #include <cstddef>
+#include <latch>
 #include <memory>
 #include <mutex>
 #include <utility>
 #include <vector>
 
+namespace {
+constexpr auto kTOTAL_WORKERS = 2uz;  // 1 reading and 1 writing worker (coroutines)
+}  // namespace
+
 namespace cluster {
 
 ClusterCommunicationService::ClusterCommunicationService(
@@ -51,6 +60,7 @@ ClusterCommunicationService::ClusterCommunicationService(
     : backend_(std::move(backend))
     , readInterval_(readInterval)
     , writeInterval_(writeInterval)
+    , finishedCountdown_(kTOTAL_WORKERS)
     , selfData_{ClioNode{
           .uuid = std::make_shared<boost::uuids::uuid>(boost::uuids::random_generator{}()),
           .updateTime = std::chrono::system_clock::time_point{}
@@ -63,22 +73,42 @@
 void
 ClusterCommunicationService::run()
 {
+    ASSERT(not running_ and not stopped_, "Can only be run once");
+    running_ = true;
+
     util::spawn(strand_, [this](boost::asio::yield_context yield) {
         boost::asio::steady_timer timer(yield.get_executor());
-        while (true) {
+        boost::system::error_code ec;
+
+        while (running_) {
             timer.expires_after(readInterval_);
-            timer.async_wait(yield);
+            auto token = cancelSignal_.slot();
+            timer.async_wait(boost::asio::bind_cancellation_slot(token, yield[ec]));
+
+            if (ec == boost::asio::error::operation_aborted or not running_)
+                break;
+
             doRead(yield);
         }
+
+        finishedCountdown_.count_down(1);
     });
 
     util::spawn(strand_, [this](boost::asio::yield_context yield) {
         boost::asio::steady_timer timer(yield.get_executor());
-        while (true) {
+        boost::system::error_code ec;
+
+        while (running_) {
             doWrite();
             timer.expires_after(writeInterval_);
-            timer.async_wait(yield);
+            auto token = cancelSignal_.slot();
+            timer.async_wait(boost::asio::bind_cancellation_slot(token, yield[ec]));
+
+            if (ec == boost::asio::error::operation_aborted or not running_)
+                break;
         }
+
+        finishedCountdown_.count_down(1);
     });
 }
 
@@ -93,9 +123,14 @@ ClusterCommunicationService::stop()
 {
     if (stopped_)
         return;
-    ctx_.stop();
-    ctx_.join();
     stopped_ = true;
+
+    // for ASAN to see through the concurrency correctly, we need to exit all coroutines before joining the ctx
+    running_ = false;
+    cancelSignal_.emit(boost::asio::cancellation_type::all);
+    finishedCountdown_.wait();
+
+    ctx_.join();
 }
 
 std::shared_ptr<boost::uuids::uuid>
diff --git a/src/cluster/ClusterCommunicationService.hpp b/src/cluster/ClusterCommunicationService.hpp
index 0b196088..3814271e 100644
--- a/src/cluster/ClusterCommunicationService.hpp
+++ b/src/cluster/ClusterCommunicationService.hpp
@@ -27,12 +27,15 @@
 #include "util/prometheus/Gauge.hpp"
 #include "util/prometheus/Prometheus.hpp"
 
+#include <boost/asio/cancellation_signal.hpp>
 #include <boost/asio/spawn.hpp>
 #include <boost/asio/strand.hpp>
 #include <boost/asio/thread_pool.hpp>
 #include <boost/uuid/uuid.hpp>
 
+#include <atomic>
 #include <chrono>
+#include <latch>
 #include <memory>
 #include <mutex>
 #include <vector>
@@ -65,11 +68,14 @@ class ClusterCommunicationService : public ClusterCommunicationServiceInterface
     std::chrono::steady_clock::duration readInterval_;
     std::chrono::steady_clock::duration writeInterval_;
 
+    boost::asio::cancellation_signal cancelSignal_;
+    std::latch finishedCountdown_;
+    std::atomic_bool running_ = false;
+    bool stopped_ = false;
+
     ClioNode selfData_;
     std::vector<ClioNode> otherNodesData_;
 
-    bool stopped_ = false;
-
 public:
     static constexpr std::chrono::milliseconds kDEFAULT_READ_INTERVAL{2100};
     static constexpr std::chrono::milliseconds kDEFAULT_WRITE_INTERVAL{1200};
diff --git a/tests/unit/cluster/ClusterCommunicationServiceTests.cpp b/tests/unit/cluster/ClusterCommunicationServiceTests.cpp
index bef14bfa..544d7e3d 100644
--- a/tests/unit/cluster/ClusterCommunicationServiceTests.cpp
+++ b/tests/unit/cluster/ClusterCommunicationServiceTests.cpp
@@ -49,6 +49,19 @@
 
 using namespace cluster;
 
+namespace {
+std::vector<ClioNode> const kOTHER_NODES_DATA = {
+    ClioNode{
+        .uuid = std::make_shared<boost::uuids::uuid>(boost::uuids::random_generator()()),
+        .updateTime = util::systemTpFromUtcStr("2015-05-15T12:00:00Z", ClioNode::kTIME_FORMAT).value()
+    },
+    ClioNode{
+        .uuid = std::make_shared<boost::uuids::uuid>(boost::uuids::random_generator()()),
+        .updateTime = util::systemTpFromUtcStr("2015-05-15T12:00:01Z", ClioNode::kTIME_FORMAT).value()
+    },
+};
+}  // namespace
+
 struct ClusterCommunicationServiceTest : util::prometheus::WithPrometheus, MockBackendTestStrict {
     ClusterCommunicationService clusterCommunicationService{
         backend_,
@@ -97,6 +110,7 @@ TEST_F(ClusterCommunicationServiceTest, Write)
 
     clusterCommunicationService.run();
     wait();
+    // destructor of clusterCommunicationService calls .stop()
 }
 
@@ -105,10 +119,13 @@ TEST_F(ClusterCommunicationServiceTest, Read_FetchFailed)
     EXPECT_CALL(*backend_, writeNodeMessage).Times(2).WillOnce([](auto&&, auto&&) {}).WillOnce([this](auto&&, auto&&) {
         notify();
     });
-    EXPECT_CALL(*backend_, fetchClioNodesData).WillOnce([](auto&&) { return std::unexpected{"Failed"}; });
+    EXPECT_CALL(*backend_, fetchClioNodesData).WillRepeatedly([](auto&&) { return std::unexpected{"Failed"}; });
 
     clusterCommunicationService.run();
     wait();
+    // call .stop() manually so that the workers exit before the expectations are invoked more times than we want
+    clusterCommunicationService.stop();
+
     EXPECT_FALSE(isHealthyMetric);
 }
 
@@ -118,10 +135,12 @@ TEST_F(ClusterCommunicationServiceTest, Read_FetchThrew)
     EXPECT_CALL(*backend_, writeNodeMessage).Times(2).WillOnce([](auto&&, auto&&) {}).WillOnce([this](auto&&, auto&&) {
         notify();
    });
-    EXPECT_CALL(*backend_, fetchClioNodesData).WillOnce(testing::Throw(data::DatabaseTimeout{}));
+    EXPECT_CALL(*backend_, fetchClioNodesData).WillRepeatedly(testing::Throw(data::DatabaseTimeout{}));
 
     clusterCommunicationService.run();
     wait();
+    clusterCommunicationService.stop();
+
     EXPECT_FALSE(isHealthyMetric);
     EXPECT_FALSE(clusterCommunicationService.clusterData().has_value());
 }
@@ -132,7 +151,7 @@ TEST_F(ClusterCommunicationServiceTest, Read_GotInvalidJson)
     EXPECT_CALL(*backend_, writeNodeMessage).Times(2).WillOnce([](auto&&, auto&&) {}).WillOnce([this](auto&&, auto&&) {
         notify();
     });
-    EXPECT_CALL(*backend_, fetchClioNodesData).WillOnce([](auto&&) {
+    EXPECT_CALL(*backend_, fetchClioNodesData).WillRepeatedly([](auto&&) {
         return std::vector<std::pair<boost::uuids::uuid, std::string>>{
             {boost::uuids::random_generator()(), "invalid json"}
         };
@@ -140,6 +159,8 @@ TEST_F(ClusterCommunicationServiceTest, Read_GotInvalidJson)
 
     clusterCommunicationService.run();
     wait();
+    clusterCommunicationService.stop();
+
     EXPECT_FALSE(isHealthyMetric);
     EXPECT_FALSE(clusterCommunicationService.clusterData().has_value());
 }
@@ -150,12 +171,14 @@ TEST_F(ClusterCommunicationServiceTest, Read_GotInvalidNodeData)
     EXPECT_CALL(*backend_, writeNodeMessage).Times(2).WillOnce([](auto&&, auto&&) {}).WillOnce([this](auto&&, auto&&) {
         notify();
     });
-    EXPECT_CALL(*backend_, fetchClioNodesData).WillOnce([](auto&&) {
+    EXPECT_CALL(*backend_, fetchClioNodesData).WillRepeatedly([](auto&&) {
         return std::vector<std::pair<boost::uuids::uuid, std::string>>{{boost::uuids::random_generator()(), "{}"}};
     });
 
     clusterCommunicationService.run();
     wait();
+    clusterCommunicationService.stop();
+
     EXPECT_FALSE(isHealthyMetric);
     EXPECT_FALSE(clusterCommunicationService.clusterData().has_value());
 }
@@ -164,23 +187,12 @@ TEST_F(ClusterCommunicationServiceTest, Read_Success)
 {
     EXPECT_TRUE(isHealthyMetric);
     EXPECT_EQ(nodesInClusterMetric.value(), 1);
-    std::vector<ClioNode> otherNodesData = {
-        ClioNode{
-            .uuid = std::make_shared<boost::uuids::uuid>(boost::uuids::random_generator()()),
-            .updateTime = util::systemTpFromUtcStr("2015-05-15T12:00:00Z", ClioNode::kTIME_FORMAT).value()
-        },
-        ClioNode{
-            .uuid = std::make_shared<boost::uuids::uuid>(boost::uuids::random_generator()()),
-            .updateTime = util::systemTpFromUtcStr("2015-05-15T12:00:01Z", ClioNode::kTIME_FORMAT).value()
-        },
-    };
-    auto const selfUuid = *clusterCommunicationService.selfUuid();
 
-    EXPECT_CALL(*backend_, writeNodeMessage).Times(2).WillOnce([](auto&&, auto&&) {}).WillOnce([&](auto&&, auto&&) {
+    EXPECT_CALL(*backend_, writeNodeMessage).Times(2).WillOnce([](auto&&, auto&&) {}).WillOnce([this](auto&&, auto&&) {
         auto const clusterData = clusterCommunicationService.clusterData();
         ASSERT_TRUE(clusterData.has_value());
-        ASSERT_EQ(clusterData->size(), otherNodesData.size() + 1);
-        for (auto const& node : otherNodesData) {
+        ASSERT_EQ(clusterData->size(), kOTHER_NODES_DATA.size() + 1);
+        for (auto const& node : kOTHER_NODES_DATA) {
             auto const it =
                 std::ranges::find_if(*clusterData, [&](ClioNode const& n) { return *(n.uuid) == *(node.uuid); });
             EXPECT_NE(it, clusterData->cend()) << boost::uuids::to_string(*node.uuid);
@@ -193,12 +205,13 @@ TEST_F(ClusterCommunicationServiceTest, Read_Success)
         }
         notify();
     });
 
-    EXPECT_CALL(*backend_, fetchClioNodesData).WillOnce([&](auto&&) {
+    EXPECT_CALL(*backend_, fetchClioNodesData).WillRepeatedly([this](auto&&) {
+        auto const selfUuid = clusterCommunicationService.selfUuid();
         std::vector<std::pair<boost::uuids::uuid, std::string>> result = {
-            {selfUuid, R"JSON({"update_time": "2015-05-15:12:00:00"})JSON"},
+            {*selfUuid, R"JSON({"update_time": "2015-05-15:12:00:00"})JSON"},
         };
-        for (auto const& node : otherNodesData) {
+        for (auto const& node : kOTHER_NODES_DATA) {
             boost::json::value jsonValue;
             boost::json::value_from(node, jsonValue);
             result.emplace_back(*node.uuid, boost::json::serialize(jsonValue));
@@ -208,6 +221,8 @@ TEST_F(ClusterCommunicationServiceTest, Read_Success)
 
     clusterCommunicationService.run();
    wait();
+    clusterCommunicationService.stop();
+
     EXPECT_TRUE(isHealthyMetric);
     EXPECT_EQ(nodesInClusterMetric.value(), 3);
 }
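
Note on the shutdown pattern introduced above: stop() flips an atomic flag, emits a cancellation signal to abort any in-flight timer wait, blocks on a latch until both worker coroutines have counted down, and only then joins the execution context (per the comment in stop(), this ordering is what lets ASAN see through the concurrency correctly). The sketch below shows the same idea in isolation; it is not Clio code. It uses plain boost::asio::spawn instead of Clio's util::spawn wrapper and a single worker instead of kTOTAL_WORKERS; the class name PeriodicWorker and the 100ms interval are made up for illustration. Assumes C++20 and a Boost new enough for per-operation cancellation (1.80+).

// Minimal, self-contained sketch of the cancellable-worker pattern (illustrative, not Clio code).
#include <boost/asio/bind_cancellation_slot.hpp>
#include <boost/asio/cancellation_signal.hpp>
#include <boost/asio/cancellation_type.hpp>
#include <boost/asio/detached.hpp>
#include <boost/asio/error.hpp>
#include <boost/asio/spawn.hpp>
#include <boost/asio/steady_timer.hpp>
#include <boost/asio/thread_pool.hpp>
#include <boost/system/error_code.hpp>

#include <atomic>
#include <chrono>
#include <cstdio>
#include <latch>
#include <thread>

class PeriodicWorker {
    boost::asio::thread_pool ctx_{1};
    boost::asio::cancellation_signal cancelSignal_;
    std::latch finished_{1};  // one worker coroutine in this sketch; the service above counts down kTOTAL_WORKERS
    std::atomic_bool running_ = false;

public:
    void
    run()
    {
        running_ = true;
        boost::asio::spawn(
            ctx_,
            [this](boost::asio::yield_context yield) {
                boost::asio::steady_timer timer(yield.get_executor());
                boost::system::error_code ec;

                while (running_) {
                    timer.expires_after(std::chrono::milliseconds{100});
                    // binding the signal's slot lets stop() abort the wait mid-sleep
                    timer.async_wait(boost::asio::bind_cancellation_slot(cancelSignal_.slot(), yield[ec]));
                    if (ec == boost::asio::error::operation_aborted or not running_)
                        break;
                    std::puts("tick");  // the periodic work would go here
                }
                finished_.count_down();  // tell stop() this coroutine has fully exited
            },
            boost::asio::detached
        );
    }

    void
    stop()
    {
        running_ = false;                                         // no new iterations may start
        cancelSignal_.emit(boost::asio::cancellation_type::all);  // interrupt a pending async_wait
        finished_.wait();                                         // block until the coroutine is gone
        ctx_.join();                                              // joining is now race-free
    }
};

int
main()
{
    PeriodicWorker worker;
    worker.run();
    std::this_thread::sleep_for(std::chrono::milliseconds{350});  // let a few ticks fire
    worker.stop();
}

The matching test-side pattern is WillOnce -> WillRepeatedly plus an explicit stop() ahead of the final EXPECT checks: the read coroutine may tick any number of times before shutdown, so the expectation has to tolerate repeats, and stopping first guarantees that no further backend calls race against the assertions.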