diff --git a/src/ripple/net/DatabaseDownloader.h b/src/ripple/net/DatabaseDownloader.h index cfb811e57..3e9209094 100644 --- a/src/ripple/net/DatabaseDownloader.h +++ b/src/ripple/net/DatabaseDownloader.h @@ -28,12 +28,14 @@ namespace ripple { class DatabaseDownloader : public HTTPDownloader { public: - DatabaseDownloader( - boost::asio::io_service& io_service, - beast::Journal j, - Config const& config); + virtual ~DatabaseDownloader() = default; private: + DatabaseDownloader( + boost::asio::io_service& io_service, + Config const& config, + beast::Journal j); + static const std::uint8_t MAX_PATH_LEN = std::numeric_limits::max(); @@ -54,8 +56,21 @@ private: Config const& config_; boost::asio::io_service& io_service_; + + friend std::shared_ptr + make_DatabaseDownloader( + boost::asio::io_service& io_service, + Config const& config, + beast::Journal j); }; +// DatabaseDownloader must be a shared_ptr because it uses shared_from_this +std::shared_ptr +make_DatabaseDownloader( + boost::asio::io_service& io_service, + Config const& config, + beast::Journal j); + } // namespace ripple #endif // RIPPLE_NET_DATABASEDOWNLOADER_H diff --git a/src/ripple/net/HTTPDownloader.h b/src/ripple/net/HTTPDownloader.h index c8f25ee91..971627e2f 100644 --- a/src/ripple/net/HTTPDownloader.h +++ b/src/ripple/net/HTTPDownloader.h @@ -41,16 +41,11 @@ namespace ripple { /** Provides an asynchronous HTTP[S] file downloader */ -class HTTPDownloader +class HTTPDownloader : public std::enable_shared_from_this { public: using error_code = boost::system::error_code; - HTTPDownloader( - boost::asio::io_service& io_service, - beast::Journal j, - Config const& config); - bool download( std::string const& host, @@ -67,6 +62,13 @@ public: virtual ~HTTPDownloader() = default; protected: + // must be accessed through a shared_ptr + // use make_XXX functions to create + HTTPDownloader( + boost::asio::io_service& io_service, + Config const& config, + beast::Journal j); + using parser = boost::beast::http::basic_parser; beast::Journal const j_; @@ -74,7 +76,6 @@ protected: void fail( boost::filesystem::path dstPath, - std::function const& complete, boost::system::error_code const& ec, std::string const& errMsg, std::shared_ptr parser); @@ -84,7 +85,7 @@ private: boost::asio::io_service::strand strand_; std::unique_ptr stream_; boost::beast::flat_buffer read_buf_; - std::atomic cancelDownloads_; + std::atomic stop_; // Used to protect sessionActive_ std::mutex m_; diff --git a/src/ripple/net/ShardDownloader.md b/src/ripple/net/ShardDownloader.md index a01855a06..96938acb1 100644 --- a/src/ripple/net/ShardDownloader.md +++ b/src/ripple/net/ShardDownloader.md @@ -79,7 +79,7 @@ will be used in the following code examples. ```c++ std::unique_ptr stream_; std::condition_variable c_; -std::atomic cancelDownloads_; +std::atomic stop_; ``` ### Graceful Shutdowns @@ -106,7 +106,7 @@ ShardArchiveHandler::onStop() ``` Inside of `HTTPDownloader::onStop()`, if a download is currently in progress, -the `cancelDownloads_` member variable is set and the thread waits for the +the `stop_` member variable is set and the thread waits for the download to stop: ```c++ @@ -115,7 +115,7 @@ HTTPDownloader::onStop() { std::unique_lock lock(m_); - cancelDownloads_ = true; + stop_ = true; if(sessionActive_) { @@ -132,7 +132,7 @@ HTTPDownloader::onStop() ##### Thread 2: The graceful shutdown is realized when the thread executing the download polls -`cancelDownloads_` after this variable has been set to `true`. Polling occurs +`stop_` after this variable has been set to `true`. Polling occurs while the file is being downloaded, in between calls to `async_read_some()`. The stop takes effect when the socket is closed and the handler function ( `do_session()` ) is exited. @@ -145,7 +145,7 @@ void HTTPDownloader::do_session() // (In between calls to async_read_some): - if(cancelDownloads_.load()) + if(stop_.load()) { close(p); return exit(); diff --git a/src/ripple/net/impl/DatabaseDownloader.cpp b/src/ripple/net/impl/DatabaseDownloader.cpp index ef6f33324..eab0d74d7 100644 --- a/src/ripple/net/impl/DatabaseDownloader.cpp +++ b/src/ripple/net/impl/DatabaseDownloader.cpp @@ -21,11 +21,21 @@ namespace ripple { +std::shared_ptr +make_DatabaseDownloader( + boost::asio::io_service& io_service, + Config const& config, + beast::Journal j) +{ + return std::shared_ptr( + new DatabaseDownloader(io_service, config, j)); +} + DatabaseDownloader::DatabaseDownloader( boost::asio::io_service& io_service, - beast::Journal j, - Config const& config) - : HTTPDownloader(io_service, j, config) + Config const& config, + beast::Journal j) + : HTTPDownloader(io_service, config, j) , config_(config) , io_service_(io_service) { @@ -42,11 +52,9 @@ DatabaseDownloader::getParser( auto p = std::make_shared>(); p->body_limit(std::numeric_limits::max()); p->get().body().open(dstPath, config_, io_service_, ec); + if (ec) - { p->get().body().close(); - fail(dstPath, complete, ec, "open", nullptr); - } return p; } diff --git a/src/ripple/net/impl/HTTPDownloader.cpp b/src/ripple/net/impl/HTTPDownloader.cpp index 03cbf4aa7..f0d554684 100644 --- a/src/ripple/net/impl/HTTPDownloader.cpp +++ b/src/ripple/net/impl/HTTPDownloader.cpp @@ -24,12 +24,12 @@ namespace ripple { HTTPDownloader::HTTPDownloader( boost::asio::io_service& io_service, - beast::Journal j, - Config const& config) + Config const& config, + beast::Journal j) : j_(j) , config_(config) , strand_(io_service) - , cancelDownloads_(false) + , stop_(false) , sessionActive_(false) { } @@ -47,19 +47,18 @@ HTTPDownloader::download( if (!checkPath(dstPath)) return false; + if (stop_) + return true; + { std::lock_guard lock(m_); - - if (cancelDownloads_) - return true; - sessionActive_ = true; } if (!strand_.running_in_this_thread()) strand_.post(std::bind( &HTTPDownloader::download, - this, + shared_from_this(), host, port, target, @@ -72,7 +71,7 @@ HTTPDownloader::download( strand_, std::bind( &HTTPDownloader::do_session, - this, + shared_from_this(), host, port, target, @@ -84,20 +83,6 @@ HTTPDownloader::download( return true; } -void -HTTPDownloader::onStop() -{ - std::unique_lock lock(m_); - - cancelDownloads_ = true; - - if (sessionActive_) - { - // Wait for the handler to exit. - c_.wait(lock, [this]() { return !sessionActive_; }); - } -} - void HTTPDownloader::do_session( std::string const host, @@ -140,21 +125,24 @@ HTTPDownloader::do_session( // because the server is shutting down, // this method notifies a 'Stoppable' // object that the session has ended. - auto exit = [this]() { - std::lock_guard lock(m_); + auto exit = [this, &dstPath, complete] { + if (!stop_) + complete(std::move(dstPath)); + + std::lock_guard lock(m_); sessionActive_ = false; c_.notify_one(); }; auto failAndExit = [&exit, &dstPath, complete, &ec, this]( std::string const& errMsg, auto p) { + fail(dstPath, ec, errMsg, p); exit(); - fail(dstPath, complete, ec, errMsg, p); }; // end lambdas //////////////////////////////////////////////////////////// - if (cancelDownloads_.load()) + if (stop_.load()) return exit(); auto p = this->getParser(dstPath, complete, ec); @@ -274,7 +262,7 @@ HTTPDownloader::do_session( // Download the file while (!p->is_done()) { - if (cancelDownloads_.load()) + if (stop_.load()) { close(p); return exit(); @@ -287,15 +275,24 @@ HTTPDownloader::do_session( close(p); exit(); +} - // Notify the completion handler - complete(std::move(dstPath)); +void +HTTPDownloader::onStop() +{ + stop_ = true; + + std::unique_lock lock(m_); + if (sessionActive_) + { + // Wait for the handler to exit. + c_.wait(lock, [this]() { return !sessionActive_; }); + } } void HTTPDownloader::fail( boost::filesystem::path dstPath, - std::function const& complete, boost::system::error_code const& ec, std::string const& errMsg, std::shared_ptr parser) @@ -321,7 +318,6 @@ HTTPDownloader::fail( JLOG(j_.error()) << "exception: " << e.what() << " in function: " << __func__; } - complete(std::move(dstPath)); } } // namespace ripple diff --git a/src/ripple/rpc/ShardArchiveHandler.h b/src/ripple/rpc/ShardArchiveHandler.h index 38e62eef8..2e2133bc2 100644 --- a/src/ripple/rpc/ShardArchiveHandler.h +++ b/src/ripple/rpc/ShardArchiveHandler.h @@ -130,7 +130,7 @@ private: // destroying sqliteDB_. ///////////////////////////////////////////////// std::mutex mutable m_; - std::unique_ptr downloader_; + std::shared_ptr downloader_; std::map archives_; bool process_; std::unique_ptr sqliteDB_; diff --git a/src/ripple/rpc/impl/ShardArchiveHandler.cpp b/src/ripple/rpc/impl/ShardArchiveHandler.cpp index 242d73887..cabee0863 100644 --- a/src/ripple/rpc/impl/ShardArchiveHandler.cpp +++ b/src/ripple/rpc/impl/ShardArchiveHandler.cpp @@ -102,8 +102,8 @@ ShardArchiveHandler::init() if (exists(downloadDir_ / stateDBName) && is_regular_file(downloadDir_ / stateDBName)) { - downloader_.reset( - new DatabaseDownloader(app_.getIOService(), j_, app_.config())); + downloader_ = + make_DatabaseDownloader(app_.getIOService(), app_.config(), j_); return initFromDB(lock); } @@ -283,8 +283,8 @@ ShardArchiveHandler::start() if (!downloader_) { // will throw if can't initialize ssl context - downloader_ = std::make_unique( - app_.getIOService(), j_, app_.config()); + downloader_ = + make_DatabaseDownloader(app_.getIOService(), app_.config(), j_); } } catch (std::exception const& e) @@ -307,15 +307,15 @@ ShardArchiveHandler::release() bool ShardArchiveHandler::next(std::lock_guard const& l) { + if (isStopping()) + return false; + if (archives_.empty()) { doRelease(l); return false; } - if (isStopping()) - return false; - auto const shardIndex{archives_.begin()->first}; // We use the sequence of the last validated ledger diff --git a/src/test/net/DatabaseDownloader_test.cpp b/src/test/net/DatabaseDownloader_test.cpp index 6333add3f..7be70e004 100644 --- a/src/test/net/DatabaseDownloader_test.cpp +++ b/src/test/net/DatabaseDownloader_test.cpp @@ -80,19 +80,22 @@ class DatabaseDownloader_test : public beast::unit_test::suite { test::StreamSink sink_; beast::Journal journal_; - // The DatabaseDownloader must be created as shared_ptr - // because it uses shared_from_this std::shared_ptr ptr_; Downloader(jtx::Env& env) : journal_{sink_} - , ptr_{std::make_shared( + , ptr_{make_DatabaseDownloader( env.app().getIOService(), - journal_, - env.app().config())} + env.app().config(), + journal_)} { } + ~Downloader() + { + ptr_->onStop(); + } + DatabaseDownloader* operator->() {