diff --git a/src/util/CMakeLists.txt b/src/util/CMakeLists.txt index 8b12f764..e5d54f5a 100644 --- a/src/util/CMakeLists.txt +++ b/src/util/CMakeLists.txt @@ -4,6 +4,7 @@ target_sources( clio_util PRIVATE Assert.cpp build/Build.cpp + Coroutine.cpp CoroutineGroup.cpp log/Logger.cpp prometheus/Http.cpp diff --git a/src/util/Coroutine.cpp b/src/util/Coroutine.cpp new file mode 100644 index 00000000..9ab6bdb1 --- /dev/null +++ b/src/util/Coroutine.cpp @@ -0,0 +1,90 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2025, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#include "util/Coroutine.hpp" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace util { + +Coroutine::Coroutine(boost::asio::yield_context&& yield, std::shared_ptr signal) + : yield_(std::move(yield)) + , cyield_(boost::asio::bind_cancellation_slot(cancellationSignal_.slot(), yield_[error_])) + , familySignal_{std::move(signal)} + , connection_{familySignal_->connect([this](boost::asio::cancellation_type_t cancellationType) { + cancellationSignal_.emit(cancellationType); + isCancelled_ = true; + })} + +{ +} + +Coroutine::~Coroutine() +{ + connection_.disconnect(); +} + +boost::system::error_code +Coroutine::error() const +{ + return error_; +} + +void +Coroutine::cancelAll(boost::asio::cancellation_type_t cancellationType) +{ + if (isCancelled()) + return; + familySignal_->operator()(cancellationType); +} + +bool +Coroutine::isCancelled() const +{ + return error_ == boost::asio::error::operation_aborted || isCancelled_; +} + +Coroutine::cancellable_yield_context_type +Coroutine::yieldContext() const +{ + return cyield_; +} + +boost::asio::any_io_executor +Coroutine::executor() const +{ + return cyield_.get().get_executor(); +} + +void +Coroutine::yield() const +{ + boost::asio::post(yield_); +} + +} // namespace util diff --git a/src/util/Coroutine.hpp b/src/util/Coroutine.hpp new file mode 100644 index 00000000..cb4317bf --- /dev/null +++ b/src/util/Coroutine.hpp @@ -0,0 +1,191 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2025, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace util { + +class Coroutine; + +/** + * @brief Concept for functions that can be used as coroutine bodies. + * Such functions must be invocable with a `Coroutine&` argument. + * @tparam Fn The function type to check. + */ +template +concept CoroutineFunction = std::invocable and not std::is_reference_v; + +/** + * @brief Manages a coroutine execution context, allowing for cooperative multitasking + * and cancellation. + * + * The Coroutine class wraps a Boost.Asio yield_context and provides mechanisms + * for spawning new coroutines, child coroutines, and managing their lifecycle, + * including cancellation. It integrates with a signal system to propagate + * cancellation requests across related coroutines. + */ +class Coroutine { +public: + /** + * @brief Type alias for a yield_context that is bound to a cancellation slot. + * This allows asynchronous operations initiated with this context to be cancelled. + */ + using cancellable_yield_context_type = + boost::asio::cancellation_slot_binder; + +private: + boost::asio::yield_context yield_; + boost::system::error_code error_; + boost::asio::cancellation_signal cancellationSignal_; + cancellable_yield_context_type cyield_; + std::atomic_bool isCancelled_{false}; + + using FamilyCancellationSignal = boost::signals2::signal; + std::shared_ptr familySignal_; + boost::signals2::connection connection_; + + /** + * @brief Private constructor to create a Coroutine instance. + * @param yield The Boost.Asio yield_context for this coroutine. + * @param signal A shared signal used for propagating cancellation requests among related coroutines. + */ + explicit Coroutine( + boost::asio::yield_context&& yield, + std::shared_ptr signal = std::make_shared() + ); + +public: + /** + * @brief Destructor for the Coroutine. + * Handles cleanup, such as disconnecting from the cancellation signal. + */ + ~Coroutine(); + + Coroutine(Coroutine const&) = delete; + Coroutine(Coroutine&&) = delete; + + Coroutine& + operator==(Coroutine&&) = delete; + + Coroutine& + operator==(Coroutine const&) = delete; + + /** + * @brief Spawns a new top-level coroutine. + * @tparam ExecutionContext The type of the I/O execution context (e.g., boost::asio::io_context). + * @tparam Fn The type of the invocable function that represents the coroutine body. + * @param ioContext The I/O execution context on which to spawn the coroutine. + * @param fn The function to be executed as the coroutine. It will receive a Coroutine& argument. + */ + template + static void + spawnNew(ExecutionContext& ioContext, Fn fn) + { + boost::asio::spawn(ioContext, [fn = std::move(fn)](boost::asio::yield_context yield) { + Coroutine thisCoroutine{std::move(yield)}; + fn(thisCoroutine); + }); + } + + /** + * @brief Spawns a child coroutine from this coroutine. + * The child coroutine shares the same cancellation signal. + * @tparam Fn The type of the invocable function that represents the child coroutine body. + * @param fn The function to be executed as the child coroutine. It will receive a Coroutine& argument. + */ + template + void + spawnChild(Fn fn) + { + if (isCancelled_) + return; + + boost::asio::spawn( + yield_, + [signal = familySignal_, fn = std::move(fn)](boost::asio::yield_context yield) mutable { + Coroutine coroutine(std::move(yield), std::move(signal)); + fn(coroutine); + } + ); + } + + /** + * @brief Returns the error code, if any, associated with the last operation in this coroutine. + * @return A boost::system::error_code indicating the status. + */ + [[nodiscard]] boost::system::error_code + error() const; + + /** + * @brief Cancels all coroutines sharing the same root cancellation signal. + * @param cancellationType The type of cancellation to perform. + * Defaults to boost::asio::cancellation_type::terminal. + */ + void + cancelAll(boost::asio::cancellation_type_t cancellationType = boost::asio::cancellation_type::terminal); + + /** + * @brief Checks if this coroutine has been cancelled. + * @return True if the coroutine is cancelled, false otherwise. + */ + [[nodiscard]] bool + isCancelled() const; + + /** + * @brief Returns the cancellable yield context associated with this coroutine. + * This context should be used for Boost.Asio asynchronous operations within the coroutine + * to enable cancellation. + * @return A cancellable_yield_context_type object. + */ + [[nodiscard]] cancellable_yield_context_type + yieldContext() const; + + /** + * @brief Returns the executor associated with this coroutine's yield context. + * @return The executor. + */ + [[nodiscard]] boost::asio::any_io_executor + executor() const; + + /** + * @brief Explicitly yields execution back to the scheduler. + * This can be used to allow other tasks to run. + */ + void + yield() const; +}; + +} // namespace util diff --git a/tests/common/util/AsioContextTestFixture.hpp b/tests/common/util/AsioContextTestFixture.hpp index 27449100..df382321 100644 --- a/tests/common/util/AsioContextTestFixture.hpp +++ b/tests/common/util/AsioContextTestFixture.hpp @@ -19,6 +19,7 @@ #pragma once +#include "util/Coroutine.hpp" #include "util/LoggerFixtures.hpp" #include @@ -79,6 +80,22 @@ private: * This is meant to be used as a base for other fixtures. */ struct SyncAsioContextTest : virtual public NoLoggerFixture { + template + void + runCoroutine(F&& f, bool allowMockLeak = false) + { + testing::MockFunction call; + if (allowMockLeak) + testing::Mock::AllowLeak(&call); + + util::Coroutine::spawnNew(ctx_, [&, _ = boost::asio::make_work_guard(ctx_)](util::Coroutine& coroutine) { + f(coroutine); + call.Call(); + }); + EXPECT_CALL(call, Call()); + runContext(); + } + template void runSpawn(F&& f, bool allowMockLeak = false) diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index e27a6e01..7baca621 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -166,6 +166,7 @@ target_sources( # Common utils util/AccountUtilsTests.cpp util/AssertTests.cpp + util/CoroutineTest.cpp util/MoveTrackerTests.cpp util/RandomTests.cpp util/RetryTests.cpp diff --git a/tests/unit/util/CoroutineTest.cpp b/tests/unit/util/CoroutineTest.cpp new file mode 100644 index 00000000..d648bf69 --- /dev/null +++ b/tests/unit/util/CoroutineTest.cpp @@ -0,0 +1,191 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2025, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#include "util/AsioContextTestFixture.hpp" +#include "util/Coroutine.hpp" +#include "util/Profiler.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +using namespace util; + +class CoroutineTest : public SyncAsioContextTest { +protected: + testing::StrictMock> fnMock_; + + static void + asyncOperation(Coroutine::cancellable_yield_context_type yield, std::chrono::steady_clock::duration duration) + { + boost::asio::steady_timer timer(yield.get().get_executor(), duration); + timer.async_wait(yield); + } +}; + +TEST_F(CoroutineTest, SpawnNew) +{ + EXPECT_CALL(fnMock_, Call); + + Coroutine::spawnNew(ctx_, fnMock_.AsStdFunction()); + ctx_.run(); +} + +TEST_F(CoroutineTest, SpawnChild) +{ + EXPECT_CALL(fnMock_, Call); + runCoroutine([this](Coroutine& coroutine) { coroutine.spawnChild(fnMock_.AsStdFunction()); }); +} + +TEST_F(CoroutineTest, SpawnChildDoesNothingWhenTheCoroutineIsCancelled) +{ + runCoroutine([this](Coroutine& coroutine) { + coroutine.cancelAll(); + coroutine.spawnChild(fnMock_.AsStdFunction()); + }); +} + +TEST_F(CoroutineTest, ErrorReturnsDefaultWhenNoError) +{ + runCoroutine([](Coroutine& coroutine) { EXPECT_EQ(coroutine.error(), boost::system::error_code{}); }); +} + +TEST_F(CoroutineTest, ErrorReturnsDefaultAfterSuccessfulOperation) +{ + runCoroutine([](Coroutine& coroutine) { + boost::asio::steady_timer timer(coroutine.executor()); + timer.expires_after(std::chrono::milliseconds(1)); + timer.async_wait(coroutine.yieldContext()); + EXPECT_EQ(coroutine.error(), boost::system::error_code{}); + }); +} + +TEST_F(CoroutineTest, ErrorReturnsErrorOfLastOperation) +{ + runCoroutine([](Coroutine& coroutine) { + boost::asio::ip::tcp::socket socket{coroutine.executor()}; + std::string buffer; + socket.async_read_some(boost::asio::buffer(buffer), coroutine.yieldContext()); + EXPECT_EQ(coroutine.error(), boost::system::errc::bad_file_descriptor); + }); +} + +TEST_F(CoroutineTest, CancelAllCancelsChildren) +{ + runCoroutine([&](Coroutine& coroutine) { + coroutine.spawnChild([](Coroutine& childCoroutine) { + auto const duration = util::timed([&childCoroutine]() { + asyncOperation(childCoroutine.yieldContext(), std::chrono::seconds{5}); + }); + EXPECT_TRUE(childCoroutine.isCancelled()); + EXPECT_LT(duration, 1000); + }); + coroutine.cancelAll(); + EXPECT_TRUE(coroutine.isCancelled()); + }); +} + +TEST_F(CoroutineTest, CancelAllCancelsParent) +{ + runCoroutine([&](Coroutine& coroutine) { + coroutine.spawnChild([](Coroutine& childCoroutine) { + childCoroutine.yield(); + childCoroutine.cancelAll(); + EXPECT_TRUE(childCoroutine.isCancelled()); + }); + + auto const duration = + util::timed([&coroutine]() { asyncOperation(coroutine.yieldContext(), std::chrono::seconds{5}); }); + EXPECT_TRUE(coroutine.isCancelled()); + EXPECT_LT(duration, 1000); + }); +} + +TEST_F(CoroutineTest, CancelAllCalledMultipleTimes) +{ + runCoroutine([&](Coroutine& coroutine) { + coroutine.spawnChild([](Coroutine& childCoroutine) { + childCoroutine.yield(); + for ([[maybe_unused]] auto const i : std::ranges::iota_view(0, 10)) { + childCoroutine.cancelAll(); + } + EXPECT_TRUE(childCoroutine.isCancelled()); + }); + + auto const duration = + util::timed([&coroutine]() { asyncOperation(coroutine.yieldContext(), std::chrono::seconds{5}); }); + EXPECT_TRUE(coroutine.isCancelled()); + EXPECT_LT(duration, 1000); + }); +} + +TEST_F(CoroutineTest, CancelAllCancelsSiblingsAndParent) +{ + EXPECT_CALL(fnMock_, Call).Times(2); + runCoroutine([&](Coroutine& parentCoroutine) { + parentCoroutine.spawnChild([&](Coroutine& child1Coroutine) { + auto duration = util::timed([&child1Coroutine]() { + asyncOperation(child1Coroutine.yieldContext(), std::chrono::seconds(5)); + }); + EXPECT_TRUE(child1Coroutine.isCancelled()); + EXPECT_LT(duration, 2000); + EXPECT_EQ(child1Coroutine.error(), boost::asio::error::operation_aborted); + fnMock_.Call(child1Coroutine); + }); + parentCoroutine.spawnChild([&](Coroutine& child2Coroutine) { + child2Coroutine.yield(); + child2Coroutine.cancelAll(); + fnMock_.Call(child2Coroutine); + }); + + auto parentDuration = util::timed([&parentCoroutine]() { + asyncOperation(parentCoroutine.yieldContext(), std::chrono::seconds(5)); + }); + EXPECT_TRUE(parentCoroutine.isCancelled()); + EXPECT_LT(parentDuration, 2000); + EXPECT_EQ(parentCoroutine.error(), boost::asio::error::operation_aborted); + }); +} + +TEST_F(CoroutineTest, Yield) +{ + testing::StrictMock> anotherFnMock; + testing::Sequence sequence; + EXPECT_CALL(fnMock_, Call).InSequence(sequence); + EXPECT_CALL(anotherFnMock, Call).InSequence(sequence); + + runCoroutine([&](Coroutine& coroutine) { + coroutine.spawnChild([&anotherFnMock](Coroutine& childCoroutine) { + childCoroutine.yield(); + anotherFnMock.Call(); + }); + fnMock_.Call(coroutine); + }); +}