Compare commits

...

23 Commits

Author SHA1 Message Date
Sergey Kuznetsov
9df3e936cc chore: Update libxrpl to 2.3.0-b4 (#1667) 2024-09-25 14:44:03 +01:00
Alex Kremer
4166c46820 fix: Workaround for gcc12 bug with defaulted destructors (#1666)
Fixes #1662
2024-09-25 14:44:03 +01:00
github-actions[bot]
f75cbd456b style: clang-tidy auto fixes (#1663)
Fixes #1662.

---------

Co-authored-by: kuznetsss <15742918+kuznetsss@users.noreply.github.com>
Co-authored-by: Peter Chen <ychen@ripple.com>
2024-09-25 14:44:03 +01:00
Peter Chen
d189651821 fix: add no lint to ignore clang-tidy (#1660)
Fixes build for
[#1659](https://github.com/XRPLF/clio/actions/runs/10956058143/job/30421296417)
2024-09-25 14:44:02 +01:00
github-actions[bot]
3f791c1315 style: clang-tidy auto fixes (#1659)
Fixes #1658. Please review and commit clang-tidy fixes.

Co-authored-by: kuznetsss <15742918+kuznetsss@users.noreply.github.com>
2024-09-25 14:44:02 +01:00
Peter Chen
418511332e chore: Revert Cassandra driver upgrade (#1656)
Reverts XRPLF/clio#1646
2024-09-25 14:44:02 +01:00
Peter Chen
e5a0477352 refactor: Clio Config (#1593)
Add constraint + parse json into Config
Second part of refactoring Clio Config; First PR found
[here](https://github.com/XRPLF/clio/pull/1544)

Steps that are left to implement:
- Replacing all the places where we fetch config values (by using
config.valueOr/MaybeValue) to instead get it from Config Definition
- Generate markdown file using Clio Config Description
2024-09-25 14:44:02 +01:00
cyan317
3118110eb8 feat: add 'force_forward' field to request (#1647)
Fix #1141
2024-09-25 14:44:01 +01:00
Alex Kremer
6d20f39f67 feat: Delete-before support in data removal tool (#1649)
Fixes #1650
2024-09-25 14:44:01 +01:00
Peter Chen
9cb1e06c8e fix: Upgrade Cassandra driver (#1646)
Fixes #1296
2024-09-25 14:44:01 +01:00
Peter Chen
423244eb4b fix: pre-push tag (#1614)
Fix issue of git was verifying incorrect Tag
2024-09-25 14:44:01 +01:00
cyan317
7aaba1cbad fix: no restriction on type field (#1644)
'type' should not matter if 'full' or 'accounts' is false. Relax the
restriction for 'type'
2024-09-25 14:44:00 +01:00
cyan317
b7c50fd73d fix: Add more restrictions to admin fields (#1643) 2024-09-25 14:44:00 +01:00
dependabot[bot]
442ee874d5 ci: Bump peter-evans/create-pull-request from 6 to 7 (#1636)
Bumps
[peter-evans/create-pull-request](https://github.com/peter-evans/create-pull-request)
from 6 to 7.
<details>
<summary>Release notes</summary>
<p><em>Sourced from <a
href="https://github.com/peter-evans/create-pull-request/releases">peter-evans/create-pull-request's
releases</a>.</em></p>
<blockquote>
<h2>Create Pull Request v7.0.0</h2>
<p> Now supports commit signing with bot-generated tokens! See
&quot;What's new&quot; below. ✍️🤖</p>
<h3>Behaviour changes</h3>
<ul>
<li>Action input <code>git-token</code> has been renamed
<code>branch-token</code>, to be more clear about its purpose. The
<code>branch-token</code> is the token that the action will use to
create and update the branch.</li>
<li>The action now handles requests that have been rate-limited by
GitHub. Requests hitting a primary rate limit will retry twice, for a
total of three attempts. Requests hitting a secondary rate limit will
not be retried.</li>
<li>The <code>pull-request-operation</code> output now returns
<code>none</code> when no operation was executed.</li>
<li>Removed deprecated output environment variable
<code>PULL_REQUEST_NUMBER</code>. Please use the
<code>pull-request-number</code> action output instead.</li>
</ul>
<h3>What's new</h3>
<ul>
<li>The action can now sign commits as <code>github-actions[bot]</code>
when using <code>GITHUB_TOKEN</code>, or your own bot when using <a
href="https://github.com/peter-evans/create-pull-request/blob/HEAD/docs/concepts-guidelines.md#authenticating-with-github-app-generated-tokens">GitHub
App tokens</a>. See <a
href="https://github.com/peter-evans/create-pull-request/blob/HEAD/docs/concepts-guidelines.md#commit-signature-verification-for-bots">commit
signing</a> for details.</li>
<li>Action input <code>draft</code> now accepts a new value
<code>always-true</code>. This will set the pull request to draft status
when the pull request is updated, as well as on creation.</li>
<li>A new action input <code>maintainer-can-modify</code> indicates
whether <a
href="https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/allowing-changes-to-a-pull-request-branch-created-from-a-fork">maintainers
can modify</a> the pull request. The default is <code>true</code>, which
retains the existing behaviour of the action.</li>
<li>A new output <code>pull-request-commits-verified</code> returns
<code>true</code> or <code>false</code>, indicating whether GitHub
considers the signature of the branch's commits to be verified.</li>
</ul>
<h2>What's Changed</h2>
<ul>
<li>build(deps-dev): bump <code>@​types/node</code> from 18.19.36 to
18.19.39 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a> in <a
href="https://redirect.github.com/peter-evans/create-pull-request/pull/3000">peter-evans/create-pull-request#3000</a></li>
<li>build(deps-dev): bump ts-jest from 29.1.5 to 29.2.0 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a> in <a
href="https://redirect.github.com/peter-evans/create-pull-request/pull/3008">peter-evans/create-pull-request#3008</a></li>
<li>build(deps-dev): bump prettier from 3.3.2 to 3.3.3 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a> in <a
href="https://redirect.github.com/peter-evans/create-pull-request/pull/3018">peter-evans/create-pull-request#3018</a></li>
<li>build(deps-dev): bump ts-jest from 29.2.0 to 29.2.2 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a> in <a
href="https://redirect.github.com/peter-evans/create-pull-request/pull/3019">peter-evans/create-pull-request#3019</a></li>
<li>build(deps-dev): bump eslint-plugin-prettier from 5.1.3 to 5.2.1 by
<a href="https://github.com/dependabot"><code>@​dependabot</code></a> in
<a
href="https://redirect.github.com/peter-evans/create-pull-request/pull/3035">peter-evans/create-pull-request#3035</a></li>
<li>build(deps-dev): bump <code>@​types/node</code> from 18.19.39 to
18.19.41 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a> in <a
href="https://redirect.github.com/peter-evans/create-pull-request/pull/3037">peter-evans/create-pull-request#3037</a></li>
<li>build(deps): bump undici from 6.19.2 to 6.19.4 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a> in <a
href="https://redirect.github.com/peter-evans/create-pull-request/pull/3036">peter-evans/create-pull-request#3036</a></li>
<li>build(deps-dev): bump ts-jest from 29.2.2 to 29.2.3 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a> in <a
href="https://redirect.github.com/peter-evans/create-pull-request/pull/3038">peter-evans/create-pull-request#3038</a></li>
<li>build(deps-dev): bump <code>@​types/node</code> from 18.19.41 to
18.19.42 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a> in <a
href="https://redirect.github.com/peter-evans/create-pull-request/pull/3070">peter-evans/create-pull-request#3070</a></li>
<li>build(deps): bump undici from 6.19.4 to 6.19.5 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a> in <a
href="https://redirect.github.com/peter-evans/create-pull-request/pull/3086">peter-evans/create-pull-request#3086</a></li>
<li>build(deps-dev): bump <code>@​types/node</code> from 18.19.42 to
18.19.43 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a> in <a
href="https://redirect.github.com/peter-evans/create-pull-request/pull/3087">peter-evans/create-pull-request#3087</a></li>
<li>build(deps-dev): bump ts-jest from 29.2.3 to 29.2.4 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a> in <a
href="https://redirect.github.com/peter-evans/create-pull-request/pull/3088">peter-evans/create-pull-request#3088</a></li>
<li>build(deps): bump undici from 6.19.5 to 6.19.7 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a> in <a
href="https://redirect.github.com/peter-evans/create-pull-request/pull/3145">peter-evans/create-pull-request#3145</a></li>
<li>build(deps-dev): bump <code>@​types/node</code> from 18.19.43 to
18.19.44 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a> in <a
href="https://redirect.github.com/peter-evans/create-pull-request/pull/3144">peter-evans/create-pull-request#3144</a></li>
<li>Update distribution by <a
href="https://github.com/actions-bot"><code>@​actions-bot</code></a> in
<a
href="https://redirect.github.com/peter-evans/create-pull-request/pull/3154">peter-evans/create-pull-request#3154</a></li>
<li>build(deps): bump undici from 6.19.7 to 6.19.8 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a> in <a
href="https://redirect.github.com/peter-evans/create-pull-request/pull/3213">peter-evans/create-pull-request#3213</a></li>
<li>build(deps-dev): bump <code>@​types/node</code> from 18.19.44 to
18.19.45 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a> in <a
href="https://redirect.github.com/peter-evans/create-pull-request/pull/3214">peter-evans/create-pull-request#3214</a></li>
<li>Update distribution by <a
href="https://github.com/actions-bot"><code>@​actions-bot</code></a> in
<a
href="https://redirect.github.com/peter-evans/create-pull-request/pull/3221">peter-evans/create-pull-request#3221</a></li>
<li>build(deps-dev): bump eslint-import-resolver-typescript from 3.6.1
to 3.6.3 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a> in <a
href="https://redirect.github.com/peter-evans/create-pull-request/pull/3255">peter-evans/create-pull-request#3255</a></li>
<li>build(deps-dev): bump <code>@​types/node</code> from 18.19.45 to
18.19.46 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a> in <a
href="https://redirect.github.com/peter-evans/create-pull-request/pull/3254">peter-evans/create-pull-request#3254</a></li>
<li>build(deps-dev): bump ts-jest from 29.2.4 to 29.2.5 by <a
href="https://github.com/dependabot"><code>@​dependabot</code></a> in <a
href="https://redirect.github.com/peter-evans/create-pull-request/pull/3256">peter-evans/create-pull-request#3256</a></li>
<li>v7 - signed commits by <a
href="https://github.com/peter-evans"><code>@​peter-evans</code></a> in
<a
href="https://redirect.github.com/peter-evans/create-pull-request/pull/3057">peter-evans/create-pull-request#3057</a></li>
</ul>
<h2>New Contributors</h2>
<ul>
<li><a
href="https://github.com/rustycl0ck"><code>@​rustycl0ck</code></a> made
their first contribution in <a
href="https://redirect.github.com/peter-evans/create-pull-request/pull/3057">peter-evans/create-pull-request#3057</a></li>
</ul>
<p><strong>Full Changelog</strong>: <a
href="https://github.com/peter-evans/create-pull-request/compare/v6.1.0...v7.0.0">https://github.com/peter-evans/create-pull-request/compare/v6.1.0...v7.0.0</a></p>
<h2>Create Pull Request v6.1.0</h2>
<p> Adds <code>pull-request-branch</code> as an action output.</p>
<h2>What's Changed</h2>
<!-- raw HTML omitted -->
</blockquote>
<p>... (truncated)</p>
</details>
<details>
<summary>Commits</summary>
<ul>
<li><a
href="8867c4aba1"><code>8867c4a</code></a>
fix: handle ambiguous argument failure on diff stat (<a
href="https://redirect.github.com/peter-evans/create-pull-request/issues/3312">#3312</a>)</li>
<li><a
href="6073f5434b"><code>6073f54</code></a>
build(deps-dev): bump <code>@​typescript-eslint/eslint-plugin</code> (<a
href="https://redirect.github.com/peter-evans/create-pull-request/issues/3291">#3291</a>)</li>
<li><a
href="6d01b5601c"><code>6d01b56</code></a>
build(deps-dev): bump eslint-plugin-import from 2.29.1 to 2.30.0 (<a
href="https://redirect.github.com/peter-evans/create-pull-request/issues/3290">#3290</a>)</li>
<li><a
href="25cf8451c3"><code>25cf845</code></a>
build(deps-dev): bump <code>@​typescript-eslint/parser</code> from
7.17.0 to 7.18.0 (<a
href="https://redirect.github.com/peter-evans/create-pull-request/issues/3289">#3289</a>)</li>
<li><a
href="d87b980a0e"><code>d87b980</code></a>
build(deps-dev): bump <code>@​types/node</code> from 18.19.46 to
18.19.48 (<a
href="https://redirect.github.com/peter-evans/create-pull-request/issues/3288">#3288</a>)</li>
<li><a
href="119d131ea9"><code>119d131</code></a>
build(deps): bump peter-evans/create-pull-request from 6 to 7 (<a
href="https://redirect.github.com/peter-evans/create-pull-request/issues/3283">#3283</a>)</li>
<li><a
href="73e6230af4"><code>73e6230</code></a>
docs: update readme</li>
<li><a
href="c0348e860f"><code>c0348e8</code></a>
ci: add v7 to workflow</li>
<li><a
href="4320041ed3"><code>4320041</code></a>
feat: signed commits (v7) (<a
href="https://redirect.github.com/peter-evans/create-pull-request/issues/3057">#3057</a>)</li>
<li><a
href="0c2a66fe4a"><code>0c2a66f</code></a>
build(deps-dev): bump ts-jest from 29.2.4 to 29.2.5 (<a
href="https://redirect.github.com/peter-evans/create-pull-request/issues/3256">#3256</a>)</li>
<li>Additional commits viewable in <a
href="https://github.com/peter-evans/create-pull-request/compare/v6...v7">compare
view</a></li>
</ul>
</details>
<br />


[![Dependabot compatibility
score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=peter-evans/create-pull-request&package-manager=github_actions&previous-version=6&new-version=7)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)

Dependabot will resolve any conflicts with this PR as long as you don't
alter it yourself. You can also trigger a rebase manually by commenting
`@dependabot rebase`.

[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)

---

<details>
<summary>Dependabot commands and options</summary>
<br />

You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits
that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after
your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge
and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating
it. You can achieve the same result by closing it manually
- `@dependabot show <dependency name> ignore conditions` will show all
of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop
Dependabot creating any more for this major version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop
Dependabot creating any more for this minor version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop
Dependabot creating any more for this dependency (unless you reopen the
PR or upgrade to it yourself)


</details>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-09-25 14:44:00 +01:00
cyan317
0679034978 fix: Don't forward ledger API if 'full' is a string (#1640)
Fix #1635
2024-09-25 14:43:59 +01:00
github-actions[bot]
b41ea34212 style: clang-tidy auto fixes (#1639)
Fixes #1638. Please review and commit clang-tidy fixes.

Co-authored-by: kuznetsss <15742918+kuznetsss@users.noreply.github.com>
2024-09-25 14:43:59 +01:00
Sergey Kuznetsov
4e147deafa fix: Subscription source bugs fix (#1626) (#1633)
Fixes #1620.
Cherry pick of #1626 into develop.

- Add timeouts for websocket operations for connections to rippled.
Without these timeouts if connection hangs for some reason, clio
wouldn't know the connection is hanging.
- Fix potential data race in choosing new subscription source which will
forward messages to users.
- Optimise switching between subscription sources.
2024-09-25 14:43:59 +01:00
Sergey Kuznetsov
b08447e8e0 fix: Fix logging in SubscriptionSource (#1617) (#1632)
Fixes #1616. 
Cherry pick of #1617 into develop.
2024-09-25 14:43:59 +01:00
cyan317
9432165ace refactor: Remove SubscriptionManagerRunner (#1623) 2024-09-25 14:43:58 +01:00
github-actions[bot]
3b6a87249c style: clang-tidy auto fixes (#1631)
Fixes #1630. Please review and commit clang-tidy fixes.

Co-authored-by: kuznetsss <15742918+kuznetsss@users.noreply.github.com>
2024-09-25 14:43:58 +01:00
Sergey Kuznetsov
b7449f72b7 test: Add test for WsConnection for ping response (#1619) 2024-09-25 14:43:58 +01:00
cyan317
443c74436e fix: not forward admin API (#1628) 2024-09-25 14:43:57 +01:00
Peter Chen
7b5e02731d fix: AccountNFT with invalid marker (#1589)
Fixes [#1497](https://github.com/XRPLF/clio/issues/1497)
Mimics the behavior of the [fix on Rippled
side](https://github.com/XRPLF/rippled/pull/5045)
2024-08-27 15:38:19 -04:00
66 changed files with 3591 additions and 1384 deletions

View File

@@ -42,7 +42,7 @@ verify_tag_signed() {
while read local_ref local_oid remote_ref remote_oid; do
# Check some things if we're pushing a branch called "release/"
if echo "$remote_ref" | grep ^refs\/heads\/release\/ &> /dev/null ; then
version=$(echo $remote_ref | awk -F/ '{print $NF}')
version=$(git tag --points-at HEAD)
echo "Looks like you're trying to push a $version release..."
echo "Making sure you've signed and tagged it."
if verify_commit_signed && verify_tag && verify_tag_signed ; then

View File

@@ -99,7 +99,7 @@ jobs:
- name: Create PR with fixes
if: ${{ steps.run_clang_tidy.outcome != 'success' }}
uses: peter-evans/create-pull-request@v6
uses: peter-evans/create-pull-request@v7
env:
GH_REPO: ${{ github.repository }}
GH_TOKEN: ${{ github.token }}

View File

@@ -28,7 +28,7 @@ class Clio(ConanFile):
'protobuf/3.21.9',
'grpc/1.50.1',
'openssl/1.1.1u',
'xrpl/2.3.0-b1',
'xrpl/2.3.0-b4',
'libbacktrace/cci.20210118'
]

View File

@@ -101,9 +101,7 @@ ClioApplication::run()
auto backend = data::make_Backend(config_);
// Manages clients subscribed to streams
auto subscriptionsRunner = feed::SubscriptionManagerRunner(config_, backend);
auto const subscriptions = subscriptionsRunner.getManager();
auto subscriptions = feed::SubscriptionManager::make_SubscriptionManager(config_, backend);
// Tracks which ledgers have been validated by the network
auto ledgers = etl::NetworkValidatedLedgers::make_ValidatedLedgers();

View File

@@ -124,6 +124,9 @@ struct Amendments {
REGISTER(NFTokenMintOffer);
REGISTER(fixReducedOffersV2);
REGISTER(fixEnforceNFTokenTrustline);
REGISTER(fixInnerObjTemplate2);
REGISTER(fixNFTokenPageLinks);
REGISTER(InvariantsV1_1);
// Obsolete but supported by libxrpl
REGISTER(CryptoConditionsSuite);

View File

@@ -111,10 +111,13 @@ LoadBalancer::LoadBalancer(
validatedLedgers,
forwardingTimeout,
[this]() {
if (not hasForwardingSource_)
if (not hasForwardingSource_.lock().get())
chooseForwardingSource();
},
[this](bool wasForwarding) {
if (wasForwarding)
chooseForwardingSource();
},
[this]() { chooseForwardingSource(); },
[this]() {
if (forwardingCache_.has_value())
forwardingCache_->invalidate();
@@ -322,11 +325,13 @@ LoadBalancer::getETLState() noexcept
void
LoadBalancer::chooseForwardingSource()
{
hasForwardingSource_ = false;
LOG(log_.info()) << "Choosing a new source to forward subscriptions";
auto hasForwardingSourceLock = hasForwardingSource_.lock();
hasForwardingSourceLock.get() = false;
for (auto& source : sources_) {
if (not hasForwardingSource_ and source->isConnected()) {
if (not hasForwardingSourceLock.get() and source->isConnected()) {
source->setForwarding(true);
hasForwardingSource_ = true;
hasForwardingSourceLock.get() = true;
} else {
source->setForwarding(false);
}

View File

@@ -25,7 +25,7 @@
#include "etl/Source.hpp"
#include "etl/impl/ForwardingCache.hpp"
#include "feed/SubscriptionManagerInterface.hpp"
#include "rpc/Errors.hpp"
#include "util/Mutex.hpp"
#include "util/config/Config.hpp"
#include "util/log/Logger.hpp"
@@ -39,7 +39,6 @@
#include <org/xrpl/rpc/v1/ledger.pb.h>
#include <xrpl/proto/org/xrpl/rpc/v1/xrp_ledger.grpc.pb.h>
#include <atomic>
#include <chrono>
#include <cstdint>
#include <expected>
@@ -76,7 +75,10 @@ private:
std::optional<ETLState> etlState_;
std::uint32_t downloadRanges_ =
DEFAULT_DOWNLOAD_RANGES; /*< The number of markers to use when downloading initial ledger */
std::atomic_bool hasForwardingSource_{false};
// Using mutext instead of atomic_bool because choosing a new source to
// forward messages should be done with a mutual exclusion otherwise there will be a race condition
util::Mutex<bool> hasForwardingSource_{false};
public:
/**

View File

@@ -53,7 +53,7 @@ namespace etl {
class SourceBase {
public:
using OnConnectHook = std::function<void()>;
using OnDisconnectHook = std::function<void()>;
using OnDisconnectHook = std::function<void(bool)>;
using OnLedgerClosedHook = std::function<void()>;
virtual ~SourceBase() = default;

View File

@@ -47,7 +47,7 @@
namespace etl::impl {
GrpcSource::GrpcSource(std::string const& ip, std::string const& grpcPort, std::shared_ptr<BackendInterface> backend)
: log_(fmt::format("ETL_Grpc[{}:{}]", ip, grpcPort)), backend_(std::move(backend))
: log_(fmt::format("GrpcSource[{}:{}]", ip, grpcPort)), backend_(std::move(backend))
{
try {
boost::asio::io_context ctx;

View File

@@ -24,6 +24,8 @@
#include "rpc/JS.hpp"
#include "util/Retry.hpp"
#include "util/log/Logger.hpp"
#include "util/prometheus/Label.hpp"
#include "util/prometheus/Prometheus.hpp"
#include "util/requests/Types.hpp"
#include <boost/algorithm/string/classification.hpp>
@@ -66,22 +68,28 @@ SubscriptionSource::SubscriptionSource(
OnConnectHook onConnect,
OnDisconnectHook onDisconnect,
OnLedgerClosedHook onLedgerClosed,
std::chrono::steady_clock::duration const connectionTimeout,
std::chrono::steady_clock::duration const wsTimeout,
std::chrono::steady_clock::duration const retryDelay
)
: log_(fmt::format("GrpcSource[{}:{}]", ip, wsPort))
: log_(fmt::format("SubscriptionSource[{}:{}]", ip, wsPort))
, wsConnectionBuilder_(ip, wsPort)
, validatedLedgers_(std::move(validatedLedgers))
, subscriptions_(std::move(subscriptions))
, strand_(boost::asio::make_strand(ioContext))
, wsTimeout_(wsTimeout)
, retry_(util::makeRetryExponentialBackoff(retryDelay, RETRY_MAX_DELAY, strand_))
, onConnect_(std::move(onConnect))
, onDisconnect_(std::move(onDisconnect))
, onLedgerClosed_(std::move(onLedgerClosed))
, lastMessageTimeSecondsSinceEpoch_(PrometheusService::gaugeInt(
"subscription_source_last_message_time",
util::prometheus::Labels({{"source", fmt::format("{}:{}", ip, wsPort)}}),
"Seconds since epoch of the last message received from rippled subscription streams"
))
{
wsConnectionBuilder_.addHeader({boost::beast::http::field::user_agent, "clio-client"})
.addHeader({"X-User", "clio-client"})
.setConnectionTimeout(connectionTimeout);
.setConnectionTimeout(wsTimeout_);
}
SubscriptionSource::~SubscriptionSource()
@@ -133,6 +141,7 @@ void
SubscriptionSource::setForwarding(bool isForwarding)
{
isForwarding_ = isForwarding;
LOG(log_.info()) << "Forwarding set to " << isForwarding_;
}
std::chrono::steady_clock::time_point
@@ -166,20 +175,22 @@ SubscriptionSource::subscribe()
}
wsConnection_ = std::move(connection).value();
isConnected_ = true;
onConnect_();
auto const& subscribeCommand = getSubscribeCommandJson();
auto const writeErrorOpt = wsConnection_->write(subscribeCommand, yield);
auto const writeErrorOpt = wsConnection_->write(subscribeCommand, yield, wsTimeout_);
if (writeErrorOpt) {
handleError(writeErrorOpt.value(), yield);
return;
}
isConnected_ = true;
LOG(log_.info()) << "Connected";
onConnect_();
retry_.reset();
while (!stop_) {
auto const message = wsConnection_->read(yield);
auto const message = wsConnection_->read(yield, wsTimeout_);
if (not message) {
handleError(message.error(), yield);
return;
@@ -224,10 +235,11 @@ SubscriptionSource::handleMessage(std::string const& message)
auto validatedLedgers = boost::json::value_to<std::string>(result.at(JS(validated_ledgers)));
setValidatedRange(std::move(validatedLedgers));
}
LOG(log_.info()) << "Received a message on ledger subscription stream. Message : " << object;
LOG(log_.debug()) << "Received a message on ledger subscription stream. Message: " << object;
} else if (object.contains(JS(type)) && object.at(JS(type)) == JS_LedgerClosed) {
LOG(log_.info()) << "Received a message on ledger subscription stream. Message : " << object;
LOG(log_.debug()) << "Received a message of type 'ledgerClosed' on ledger subscription stream. Message: "
<< object;
if (object.contains(JS(ledger_index))) {
ledgerIndex = object.at(JS(ledger_index)).as_int64();
}
@@ -245,10 +257,13 @@ SubscriptionSource::handleMessage(std::string const& message)
// 2 - Validated transaction
// Only forward proposed transaction, validated transactions are sent by Clio itself
if (object.contains(JS(transaction)) and !object.contains(JS(meta))) {
LOG(log_.debug()) << "Forwarding proposed transaction: " << object;
subscriptions_->forwardProposedTransaction(object);
} else if (object.contains(JS(type)) && object.at(JS(type)) == JS_ValidationReceived) {
LOG(log_.debug()) << "Forwarding validation: " << object;
subscriptions_->forwardValidation(object);
} else if (object.contains(JS(type)) && object.at(JS(type)) == JS_ManifestReceived) {
LOG(log_.debug()) << "Forwarding manifest: " << object;
subscriptions_->forwardManifest(object);
}
}
@@ -261,7 +276,7 @@ SubscriptionSource::handleMessage(std::string const& message)
return std::nullopt;
} catch (std::exception const& e) {
LOG(log_.error()) << "Exception in handleMessage : " << e.what();
LOG(log_.error()) << "Exception in handleMessage: " << e.what();
return util::requests::RequestError{fmt::format("Error handling message: {}", e.what())};
}
}
@@ -270,16 +285,14 @@ void
SubscriptionSource::handleError(util::requests::RequestError const& error, boost::asio::yield_context yield)
{
isConnected_ = false;
isForwarding_ = false;
bool const wasForwarding = isForwarding_.exchange(false);
if (not stop_) {
onDisconnect_();
LOG(log_.info()) << "Disconnected";
onDisconnect_(wasForwarding);
}
if (wsConnection_ != nullptr) {
auto const err = wsConnection_->close(yield);
if (err) {
LOG(log_.error()) << "Error closing websocket connection: " << err->message();
}
wsConnection_->close(yield);
wsConnection_.reset();
}
@@ -306,7 +319,11 @@ SubscriptionSource::logError(util::requests::RequestError const& error) const
void
SubscriptionSource::setLastMessageTime()
{
lastMessageTime_.lock().get() = std::chrono::steady_clock::now();
lastMessageTimeSecondsSinceEpoch_.get().set(
std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch()).count()
);
auto lock = lastMessageTime_.lock();
lock.get() = std::chrono::steady_clock::now();
}
void

View File

@@ -25,6 +25,7 @@
#include "util/Mutex.hpp"
#include "util/Retry.hpp"
#include "util/log/Logger.hpp"
#include "util/prometheus/Gauge.hpp"
#include "util/requests/Types.hpp"
#include "util/requests/WsConnection.hpp"
@@ -37,6 +38,7 @@
#include <atomic>
#include <chrono>
#include <cstdint>
#include <functional>
#include <future>
#include <memory>
#include <optional>
@@ -71,6 +73,8 @@ private:
boost::asio::strand<boost::asio::io_context::executor_type> strand_;
std::chrono::steady_clock::duration wsTimeout_;
util::Retry retry_;
OnConnectHook onConnect_;
@@ -83,9 +87,11 @@ private:
util::Mutex<std::chrono::steady_clock::time_point> lastMessageTime_;
std::reference_wrapper<util::prometheus::GaugeInt> lastMessageTimeSecondsSinceEpoch_;
std::future<void> runFuture_;
static constexpr std::chrono::seconds CONNECTION_TIMEOUT{30};
static constexpr std::chrono::seconds WS_TIMEOUT{30};
static constexpr std::chrono::seconds RETRY_MAX_DELAY{30};
static constexpr std::chrono::seconds RETRY_DELAY{1};
@@ -103,7 +109,7 @@ public:
* @param onDisconnect The onDisconnect hook. Called when the connection is lost
* @param onLedgerClosed The onLedgerClosed hook. Called when the ledger is closed but only if the source is
* forwarding
* @param connectionTimeout The connection timeout. Defaults to 30 seconds
* @param wsTimeout A timeout for websocket operations. Defaults to 30 seconds
* @param retryDelay The retry delay. Defaults to 1 second
*/
SubscriptionSource(
@@ -115,7 +121,7 @@ public:
OnConnectHook onConnect,
OnDisconnectHook onDisconnect,
OnLedgerClosedHook onLedgerClosed,
std::chrono::steady_clock::duration const connectionTimeout = CONNECTION_TIMEOUT,
std::chrono::steady_clock::duration const wsTimeout = WS_TIMEOUT,
std::chrono::steady_clock::duration const retryDelay = RETRY_DELAY
);

View File

@@ -30,6 +30,7 @@
#include "feed/impl/TransactionFeed.hpp"
#include "util/async/AnyExecutionContext.hpp"
#include "util/async/context/BasicExecutionContext.hpp"
#include "util/config/Config.hpp"
#include "util/log/Logger.hpp"
#include <boost/asio/executor_work_guard.hpp>
@@ -44,6 +45,7 @@
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
/**
@@ -67,16 +69,36 @@ class SubscriptionManager : public SubscriptionManagerInterface {
impl::ProposedTransactionFeed proposedTransactionFeed_;
public:
/**
* @brief Factory function to create a new SubscriptionManager with a PoolExecutionContext.
*
* @param config The configuration to use
* @param backend The backend to use
* @return A shared pointer to a new instance of SubscriptionManager
*/
static std::shared_ptr<SubscriptionManager>
make_SubscriptionManager(util::Config const& config, std::shared_ptr<data::BackendInterface const> const& backend)
{
auto const workersNum = config.valueOr<std::uint64_t>("subscription_workers", 1);
util::Logger const logger{"Subscriptions"};
LOG(logger.info()) << "Starting subscription manager with " << workersNum << " workers";
return std::make_shared<feed::SubscriptionManager>(util::async::PoolExecutionContext(workersNum), backend);
}
/**
* @brief Construct a new Subscription Manager object
*
* @param executor The executor to use to publish the feeds
* @param backend The backend to use
*/
template <class ExecutorCtx>
SubscriptionManager(ExecutorCtx& executor, std::shared_ptr<data::BackendInterface const> const& backend)
SubscriptionManager(
util::async::AnyExecutionContext&& executor,
std::shared_ptr<data::BackendInterface const> const& backend
)
: backend_(backend)
, ctx_(executor)
, ctx_(std::move(executor))
, manifestFeed_(ctx_, "manifest")
, validationsFeed_(ctx_, "validations")
, ledgerFeed_(ctx_)
@@ -291,41 +313,4 @@ public:
report() const final;
};
/**
* @brief The help class to run the subscription manager. The container of PoolExecutionContext which is used to publish
* the feeds.
*/
class SubscriptionManagerRunner {
std::uint64_t workersNum_;
using ActualExecutionCtx = util::async::PoolExecutionContext;
ActualExecutionCtx ctx_;
std::shared_ptr<SubscriptionManager> subscriptionManager_;
util::Logger logger_{"Subscriptions"};
public:
/**
* @brief Construct a new Subscription Manager Runner object
*
* @param config The configuration
* @param backend The backend to use
*/
SubscriptionManagerRunner(util::Config const& config, std::shared_ptr<data::BackendInterface> const& backend)
: workersNum_(config.valueOr<std::uint64_t>("subscription_workers", 1))
, ctx_(workersNum_)
, subscriptionManager_(std::make_shared<SubscriptionManager>(ctx_, backend))
{
LOG(logger_.info()) << "Starting subscription manager with " << workersNum_ << " workers";
}
/**
* @brief Get the subscription manager
*
* @return The subscription manager
*/
std::shared_ptr<SubscriptionManager>
getManager()
{
return subscriptionManager_;
}
};
} // namespace feed

View File

@@ -22,6 +22,7 @@
#include "data/BackendInterface.hpp"
#include "rpc/Counters.hpp"
#include "rpc/Errors.hpp"
#include "rpc/RPCHelpers.hpp"
#include "rpc/WorkQueue.hpp"
#include "rpc/common/HandlerProvider.hpp"
#include "rpc/common/Types.hpp"
@@ -131,8 +132,13 @@ public:
Result
buildResponse(web::Context const& ctx)
{
if (forwardingProxy_.shouldForward(ctx))
if (forwardingProxy_.shouldForward(ctx)) {
// Disallow forwarding of the admin api, only user api is allowed for security reasons.
if (isAdminCmd(ctx.method, ctx.params))
return Result{Status{RippledError::rpcNO_PERMISSION}};
return forwardingProxy_.forward(ctx);
}
if (backend_->isTooBusy()) {
LOG(log_.error()) << "Database is too busy. Rejecting request";

View File

@@ -36,6 +36,7 @@
#include <boost/json/array.hpp>
#include <boost/json/object.hpp>
#include <boost/json/parse.hpp>
#include <boost/json/serialize.hpp>
#include <boost/json/string.hpp>
#include <boost/json/value.hpp>
#include <boost/json/value_to.hpp>
@@ -49,6 +50,7 @@
#include <xrpl/basics/chrono.h>
#include <xrpl/basics/strHex.h>
#include <xrpl/beast/utility/Zero.h>
#include <xrpl/json/json_reader.h>
#include <xrpl/json/json_value.h>
#include <xrpl/protocol/AccountID.h>
#include <xrpl/protocol/Book.h>
@@ -1273,6 +1275,31 @@ specifiesCurrentOrClosedLedger(boost::json::object const& request)
return false;
}
bool
isAdminCmd(std::string const& method, boost::json::object const& request)
{
if (method == JS(ledger)) {
auto const requestStr = boost::json::serialize(request);
Json::Value jv;
Json::Reader{}.parse(requestStr, jv);
// rippled considers string/non-zero int/non-empty array/ non-empty json as true.
// Use rippled's API asBool to get the same result.
// https://github.com/XRPLF/rippled/issues/5119
auto const isFieldSet = [&jv](auto const field) { return jv.isMember(field) and jv[field].asBool(); };
// According to doc
// https://xrpl.org/docs/references/http-websocket-apis/public-api-methods/ledger-methods/ledger,
// full/accounts/type are admin only, but type only works when full/accounts are set, so we don't need to check
// type.
if (isFieldSet(JS(full)) or isFieldSet(JS(accounts)))
return true;
}
if (method == JS(feature) and request.contains(JS(vetoed)))
return true;
return false;
}
std::variant<ripple::uint256, Status>
getNFTID(boost::json::object const& request)
{

View File

@@ -557,6 +557,16 @@ parseIssue(boost::json::object const& issue);
bool
specifiesCurrentOrClosedLedger(boost::json::object const& request);
/**
* @brief Check whether a request requires administrative privileges on rippled side.
*
* @param method The method name to check
* @param request The request to check
* @return true if the request requires ADMIN role
*/
bool
isAdminCmd(std::string const& method, boost::json::object const& request);
/**
* @brief Get the NFTID from the request
*

View File

@@ -200,15 +200,13 @@ CustomValidator CustomValidators::SubscribeStreamValidator =
"ledger", "transactions", "transactions_proposed", "book_changes", "manifests", "validations"
};
static std::unordered_set<std::string> const reportingNotSupportStreams = {
"peer_status", "consensus", "server"
};
static std::unordered_set<std::string> const notSupportStreams = {"peer_status", "consensus", "server"};
for (auto const& v : value.as_array()) {
if (!v.is_string())
return Error{Status{RippledError::rpcINVALID_PARAMS, "streamNotString"}};
if (reportingNotSupportStreams.contains(boost::json::value_to<std::string>(v)))
return Error{Status{RippledError::rpcREPORTING_UNSUPPORTED}};
if (notSupportStreams.contains(boost::json::value_to<std::string>(v)))
return Error{Status{RippledError::rpcNOT_SUPPORTED}};
if (not validStreams.contains(boost::json::value_to<std::string>(v)))
return Error{Status{RippledError::rpcSTREAM_MALFORMED}};

View File

@@ -60,10 +60,6 @@ public:
if (ctx.method == "subscribe" || ctx.method == "unsubscribe")
return false;
// Disallow forwarding of the admin api, only user api is allowed for security reasons.
if (ctx.method == "feature" and request.contains("vetoed"))
return false;
if (handlerProvider_->isClioOnly(ctx.method))
return false;
@@ -73,6 +69,9 @@ public:
if (specifiesCurrentOrClosedLedger(request))
return true;
if (isForcedForward(ctx))
return true;
auto const checkAccountInfoForward = [&]() {
return ctx.method == "account_info" and request.contains("queue") and request.at("queue").is_bool() and
request.at("queue").as_bool();
@@ -142,6 +141,14 @@ private:
{
return handlerProvider_->contains(method) || isProxied(method);
}
bool
isForcedForward(web::Context const& ctx) const
{
static constexpr auto FORCE_FORWARD = "force_forward";
return ctx.isAdmin and ctx.params.contains(FORCE_FORWARD) and ctx.params.at(FORCE_FORWARD).is_bool() and
ctx.params.at(FORCE_FORWARD).as_bool();
}
};
} // namespace rpc::impl

View File

@@ -78,10 +78,17 @@ AccountNFTsHandler::process(AccountNFTsHandler::Input input, Context const& ctx)
input.marker ? ripple::uint256{input.marker->c_str()} : ripple::keylet::nftpage_max(*accountID).key;
auto const blob = sharedPtrBackend_->fetchLedgerObject(pageKey, lgrInfo.seq, ctx.yield);
if (!blob)
if (!blob) {
if (input.marker.has_value())
return Error{Status{RippledError::rpcINVALID_PARAMS, "Marker field does not match any valid Page ID"}};
return response;
}
std::optional<ripple::SLE const> page{ripple::SLE{ripple::SerialIter{blob->data(), blob->size()}, pageKey}};
if (page->getType() != ripple::ltNFTOKEN_PAGE)
return Error{Status{RippledError::rpcINVALID_PARAMS, "Marker matches Page ID from another Account"}};
auto numPages = 0u;
while (page) {

View File

@@ -25,9 +25,12 @@ target_sources(
TimeUtils.cpp
TxUtils.cpp
LedgerUtils.cpp
newconfig/ConfigDefinition.cpp
newconfig/ObjectView.cpp
newconfig/Array.cpp
newconfig/ArrayView.cpp
newconfig/ConfigConstraints.cpp
newconfig/ConfigDefinition.cpp
newconfig/ConfigFileJson.cpp
newconfig/ObjectView.cpp
newconfig/ValueView.cpp
)

View File

@@ -369,7 +369,7 @@ public:
* @brief Block until all operations are completed
*/
void
join() noexcept
join() const noexcept
{
context_.join();
}

View File

@@ -0,0 +1,84 @@
//------------------------------------------------------------------------------
/*
This file is part of clio: https://github.com/XRPLF/clio
Copyright (c) 2024, the clio developers.
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#include "util/newconfig/Array.hpp"
#include "util/Assert.hpp"
#include "util/newconfig/ConfigValue.hpp"
#include "util/newconfig/Error.hpp"
#include "util/newconfig/Types.hpp"
#include <cstddef>
#include <optional>
#include <string_view>
#include <utility>
#include <vector>
namespace util::config {
Array::Array(ConfigValue arg) : itemPattern_{std::move(arg)}
{
}
std::optional<Error>
Array::addValue(Value value, std::optional<std::string_view> key)
{
auto const& configValPattern = itemPattern_;
auto const constraint = configValPattern.getConstraint();
auto newElem = constraint.has_value() ? ConfigValue{configValPattern.type()}.withConstraint(constraint->get())
: ConfigValue{configValPattern.type()};
if (auto const maybeError = newElem.setValue(value, key); maybeError.has_value())
return maybeError;
elements_.emplace_back(std::move(newElem));
return std::nullopt;
}
size_t
Array::size() const
{
return elements_.size();
}
ConfigValue const&
Array::at(std::size_t idx) const
{
ASSERT(idx < elements_.size(), "Index is out of scope");
return elements_[idx];
}
ConfigValue const&
Array::getArrayPattern() const
{
return itemPattern_;
}
std::vector<ConfigValue>::const_iterator
Array::begin() const
{
return elements_.begin();
}
std::vector<ConfigValue>::const_iterator
Array::end() const
{
return elements_.end();
}
} // namespace util::config

View File

@@ -19,47 +19,42 @@
#pragma once
#include "util/Assert.hpp"
#include "util/newconfig/ConfigValue.hpp"
#include "util/newconfig/ObjectView.hpp"
#include "util/newconfig/ValueView.hpp"
#include "util/newconfig/Error.hpp"
#include "util/newconfig/Types.hpp"
#include <cstddef>
#include <iterator>
#include <type_traits>
#include <utility>
#include <optional>
#include <string_view>
#include <vector>
namespace util::config {
/**
* @brief Array definition for Json/Yaml config
* @brief Array definition to store multiple values provided by the user from Json/Yaml
*
* Used in ClioConfigDefinition to represent multiple potential values (like whitelist)
* Is constructed with only 1 element which states which type/constraint must every element
* In the array satisfy
*/
class Array {
public:
/**
* @brief Constructs an Array with the provided arguments
* @brief Constructs an Array with provided Arg
*
* @tparam Args Types of the arguments
* @param args Arguments to initialize the elements of the Array
* @param arg Argument to set the type and constraint of ConfigValues in Array
*/
template <typename... Args>
constexpr Array(Args&&... args) : elements_{std::forward<Args>(args)...}
{
}
Array(ConfigValue arg);
/**
* @brief Add ConfigValues to Array class
*
* @param value The ConfigValue to add
* @param key optional string key to include that will show in error message
* @return optional error if adding config value to array fails. nullopt otherwise
*/
void
emplaceBack(ConfigValue value)
{
elements_.push_back(std::move(value));
}
std::optional<Error>
addValue(Value value, std::optional<std::string_view> key = std::nullopt);
/**
* @brief Returns the number of values stored in the Array
@@ -67,10 +62,7 @@ public:
* @return Number of values stored in the Array
*/
[[nodiscard]] size_t
size() const
{
return elements_.size();
}
size() const;
/**
* @brief Returns the ConfigValue at the specified index
@@ -79,13 +71,35 @@ public:
* @return ConfigValue at the specified index
*/
[[nodiscard]] ConfigValue const&
at(std::size_t idx) const
{
ASSERT(idx < elements_.size(), "index is out of scope");
return elements_[idx];
}
at(std::size_t idx) const;
/**
* @brief Returns the ConfigValue that defines the type/constraint every
* ConfigValue must follow in Array
*
* @return The item_pattern
*/
[[nodiscard]] ConfigValue const&
getArrayPattern() const;
/**
* @brief Returns an iterator to the beginning of the ConfigValue vector.
*
* @return A constant iterator to the beginning of the vector.
*/
[[nodiscard]] std::vector<ConfigValue>::const_iterator
begin() const;
/**
* @brief Returns an iterator to the end of the ConfigValue vector.
*
* @return A constant iterator to the end of the vector.
*/
[[nodiscard]] std::vector<ConfigValue>::const_iterator
end() const;
private:
ConfigValue itemPattern_;
std::vector<ConfigValue> elements_;
};

View File

@@ -0,0 +1,103 @@
//------------------------------------------------------------------------------
/*
This file is part of clio: https://github.com/XRPLF/clio
Copyright (c) 2024, the clio developers.
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#include "util/newconfig/ConfigConstraints.hpp"
#include "util/newconfig/Error.hpp"
#include "util/newconfig/Types.hpp"
#include <cstdint>
#include <optional>
#include <regex>
#include <stdexcept>
#include <string>
#include <variant>
namespace util::config {
std::optional<Error>
PortConstraint::checkTypeImpl(Value const& port) const
{
if (!(std::holds_alternative<int64_t>(port) || std::holds_alternative<std::string>(port)))
return Error{"Port must be a string or integer"};
return std::nullopt;
}
std::optional<Error>
PortConstraint::checkValueImpl(Value const& port) const
{
uint32_t p = 0;
if (std::holds_alternative<std::string>(port)) {
try {
p = static_cast<uint32_t>(std::stoi(std::get<std::string>(port)));
} catch (std::invalid_argument const& e) {
return Error{"Port string must be an integer."};
}
} else {
p = static_cast<uint32_t>(std::get<int64_t>(port));
}
if (p >= portMin && p <= portMax)
return std::nullopt;
return Error{"Port does not satisfy the constraint bounds"};
}
std::optional<Error>
ValidIPConstraint::checkTypeImpl(Value const& ip) const
{
if (!std::holds_alternative<std::string>(ip))
return Error{"Ip value must be a string"};
return std::nullopt;
}
std::optional<Error>
ValidIPConstraint::checkValueImpl(Value const& ip) const
{
if (std::get<std::string>(ip) == "localhost")
return std::nullopt;
static std::regex const ipv4(
R"(^((25[0-5]|2[0-4][0-9]|1[0-9]{2}|[1-9]?[0-9])\.){3}(25[0-5]|2[0-4][0-9]|1[0-9]{2}|[1-9]?[0-9])$)"
);
static std::regex const ip_url(
R"(^((http|https):\/\/)?((([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,6})|(((25[0-5]|2[0-4][0-9]|1[0-9]{2}|[1-9]?[0-9])\.){3}(25[0-5]|2[0-4][0-9]|1[0-9]{2}|[1-9]?[0-9])))(:\d{1,5})?(\/[^\s]*)?$)"
);
if (std::regex_match(std::get<std::string>(ip), ipv4) || std::regex_match(std::get<std::string>(ip), ip_url))
return std::nullopt;
return Error{"Ip is not a valid ip address"};
}
std::optional<Error>
PositiveDouble::checkTypeImpl(Value const& num) const
{
if (!(std::holds_alternative<double>(num) || std::holds_alternative<int64_t>(num)))
return Error{"Double number must be of type int or double"};
return std::nullopt;
}
std::optional<Error>
PositiveDouble::checkValueImpl(Value const& num) const
{
if (std::get<double>(num) >= 0)
return std::nullopt;
return Error{"Double number must be greater than 0"};
}
} // namespace util::config

View File

@@ -0,0 +1,362 @@
//------------------------------------------------------------------------------
/*
This file is part of clio: https://github.com/XRPLF/clio
Copyright (c) 2024, the clio developers.
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#pragma once
#include "rpc/common/APIVersion.hpp"
#include "util/log/Logger.hpp"
#include "util/newconfig/Error.hpp"
#include "util/newconfig/Types.hpp"
#include <fmt/core.h>
#include <fmt/format.h>
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <string>
#include <string_view>
#include <variant>
namespace util::config {
class ValueView;
class ConfigValue;
/**
* @brief specific values that are accepted for logger levels in config.
*/
static constexpr std::array<char const*, 7> LOG_LEVELS = {
"trace",
"debug",
"info",
"warning",
"error",
"fatal",
"count",
};
/**
* @brief specific values that are accepted for logger tag style in config.
*/
static constexpr std::array<char const*, 5> LOG_TAGS = {
"int",
"uint",
"null",
"none",
"uuid",
};
/**
* @brief specific values that are accepted for cache loading in config.
*/
static constexpr std::array<char const*, 3> LOAD_CACHE_MODE = {
"sync",
"async",
"none",
};
/**
* @brief specific values that are accepted for database type in config.
*/
static constexpr std::array<char const*, 1> DATABASE_TYPE = {"cassandra"};
/**
* @brief An interface to enforce constraints on certain values within ClioConfigDefinition.
*/
class Constraint {
public:
constexpr virtual ~Constraint() noexcept = default;
/**
* @brief Check if the value meets the specific constraint.
*
* @param val The value to be checked
* @return An Error object if the constraint is not met, nullopt otherwise
*/
[[nodiscard]]
std::optional<Error>
checkConstraint(Value const& val) const
{
if (auto const maybeError = checkTypeImpl(val); maybeError.has_value())
return maybeError;
return checkValueImpl(val);
}
protected:
/**
* @brief Creates an error message for all constraints that must satisfy certain hard-coded values.
*
* @tparam arrSize, the size of the array of hardcoded values
* @param key The key to the value
* @param value The value the user provided
* @param arr The array with hard-coded values to add to error message
* @return The error message specifying what the value of key must be
*/
template <std::size_t arrSize>
constexpr std::string
makeErrorMsg(std::string_view key, Value const& value, std::array<char const*, arrSize> arr) const
{
// Extract the value from the variant
auto const valueStr = std::visit([](auto const& v) { return fmt::format("{}", v); }, value);
// Create the error message
return fmt::format(
R"(You provided value "{}". Key "{}"'s value must be one of the following: {})",
valueStr,
key,
fmt::join(arr, ", ")
);
}
/**
* @brief Check if the value is of a correct type for the constraint.
*
* @param val The value type to be checked
* @return An Error object if the constraint is not met, nullopt otherwise
*/
virtual std::optional<Error>
checkTypeImpl(Value const& val) const = 0;
/**
* @brief Check if the value is within the constraint.
*
* @param val The value type to be checked
* @return An Error object if the constraint is not met, nullopt otherwise
*/
virtual std::optional<Error>
checkValueImpl(Value const& val) const = 0;
};
/**
* @brief A constraint to ensure the port number is within a valid range.
*/
class PortConstraint final : public Constraint {
public:
constexpr ~PortConstraint() noexcept override = default;
private:
/**
* @brief Check if the type of the value is correct for this specific constraint.
*
* @param port The type to be checked
* @return An Error object if the constraint is not met, nullopt otherwise
*/
[[nodiscard]] std::optional<Error>
checkTypeImpl(Value const& port) const override;
/**
* @brief Check if the value is within the constraint.
*
* @param port The value to be checked
* @return An Error object if the constraint is not met, nullopt otherwise
*/
[[nodiscard]] std::optional<Error>
checkValueImpl(Value const& port) const override;
static constexpr uint32_t portMin = 1;
static constexpr uint32_t portMax = 65535;
};
/**
* @brief A constraint to ensure the IP address is valid.
*/
class ValidIPConstraint final : public Constraint {
public:
constexpr ~ValidIPConstraint() noexcept override = default;
private:
/**
* @brief Check if the type of the value is correct for this specific constraint.
*
* @param ip The type to be checked.
* @return An optional Error if the constraint is not met, std::nullopt otherwise
*/
[[nodiscard]] std::optional<Error>
checkTypeImpl(Value const& ip) const override;
/**
* @brief Check if the value is within the constraint.
*
* @param ip The value to be checked
* @return An Error object if the constraint is not met, nullopt otherwise
*/
[[nodiscard]] std::optional<Error>
checkValueImpl(Value const& ip) const override;
};
/**
* @brief A constraint class to ensure the provided value is one of the specified values in an array.
*
* @tparam arrSize The size of the array containing the valid values for the constraint
*/
template <std::size_t arrSize>
class OneOf final : public Constraint {
public:
/**
* @brief Constructs a constraint where the value must be one of the values in the provided array.
*
* @param key The key of the ConfigValue that has this constraint
* @param arr The value that has this constraint must be of the values in arr
*/
constexpr OneOf(std::string_view key, std::array<char const*, arrSize> arr) : key_{key}, arr_{arr}
{
}
constexpr ~OneOf() noexcept override = default;
private:
/**
* @brief Check if the type of the value is correct for this specific constraint.
*
* @param val The type to be checked
* @return An Error object if the constraint is not met, nullopt otherwise
*/
[[nodiscard]] std::optional<Error>
checkTypeImpl(Value const& val) const override
{
if (!std::holds_alternative<std::string>(val))
return Error{fmt::format(R"(Key "{}"'s value must be a string)", key_)};
return std::nullopt;
}
/**
* @brief Check if the value matches one of the value in the provided array
*
* @param val The value to check
* @return An Error object if the constraint is not met, nullopt otherwise
*/
[[nodiscard]] std::optional<Error>
checkValueImpl(Value const& val) const override
{
namespace rg = std::ranges;
auto const check = [&val](std::string_view name) { return std::get<std::string>(val) == name; };
if (rg::any_of(arr_, check))
return std::nullopt;
return Error{makeErrorMsg(key_, val, arr_)};
}
std::string_view key_;
std::array<char const*, arrSize> arr_;
};
/**
* @brief A constraint class to ensure an integer value is between two numbers (inclusive)
*/
template <typename numType>
class NumberValueConstraint final : public Constraint {
public:
/**
* @brief Constructs a constraint where the number must be between min_ and max_.
*
* @param min the minimum number it can be to satisfy this constraint
* @param max the maximum number it can be to satisfy this constraint
*/
constexpr NumberValueConstraint(numType min, numType max) : min_{min}, max_{max}
{
}
constexpr ~NumberValueConstraint() noexcept override = default;
private:
/**
* @brief Check if the type of the value is correct for this specific constraint.
*
* @param num The type to be checked
* @return An Error object if the constraint is not met, nullopt otherwise
*/
[[nodiscard]] std::optional<Error>
checkTypeImpl(Value const& num) const override
{
if (!std::holds_alternative<int64_t>(num))
return Error{"Number must be of type integer"};
return std::nullopt;
}
/**
* @brief Check if the number is positive.
*
* @param num The number to check
* @return An Error object if the constraint is not met, nullopt otherwise
*/
[[nodiscard]] std::optional<Error>
checkValueImpl(Value const& num) const override
{
auto const numValue = std::get<int64_t>(num);
if (numValue >= static_cast<int64_t>(min_) && numValue <= static_cast<int64_t>(max_))
return std::nullopt;
return Error{fmt::format("Number must be between {} and {}", min_, max_)};
}
numType min_;
numType max_;
};
/**
* @brief A constraint to ensure a double number is positive
*/
class PositiveDouble final : public Constraint {
public:
constexpr ~PositiveDouble() noexcept override = default;
private:
/**
* @brief Check if the type of the value is correct for this specific constraint.
*
* @param num The type to be checked
* @return An Error object if the constraint is not met, nullopt otherwise
*/
[[nodiscard]] std::optional<Error>
checkTypeImpl(Value const& num) const override;
/**
* @brief Check if the number is positive.
*
* @param num The number to check
* @return An Error object if the constraint is not met, nullopt otherwise
*/
[[nodiscard]] std::optional<Error>
checkValueImpl(Value const& num) const override;
};
static constinit PortConstraint validatePort{};
static constinit ValidIPConstraint validateIP{};
static constinit OneOf validateChannelName{"channel", Logger::CHANNELS};
static constinit OneOf validateLogLevelName{"log_level", LOG_LEVELS};
static constinit OneOf validateCassandraName{"database.type", DATABASE_TYPE};
static constinit OneOf validateLoadMode{"cache.load", LOAD_CACHE_MODE};
static constinit OneOf validateLogTag{"log_tag_style", LOG_TAGS};
static constinit PositiveDouble validatePositiveDouble{};
static constinit NumberValueConstraint<uint16_t> validateUint16{
std::numeric_limits<uint16_t>::min(),
std::numeric_limits<uint16_t>::max()
};
static constinit NumberValueConstraint<uint32_t> validateUint32{
std::numeric_limits<uint32_t>::min(),
std::numeric_limits<uint32_t>::max()
};
static constinit NumberValueConstraint<uint32_t> validateApiVersion{rpc::API_VERSION_MIN, rpc::API_VERSION_MAX};
} // namespace util::config

View File

@@ -20,10 +20,15 @@
#include "util/newconfig/ConfigDefinition.hpp"
#include "util/Assert.hpp"
#include "util/OverloadSet.hpp"
#include "util/newconfig/Array.hpp"
#include "util/newconfig/ArrayView.hpp"
#include "util/newconfig/ConfigConstraints.hpp"
#include "util/newconfig/ConfigFileInterface.hpp"
#include "util/newconfig/ConfigValue.hpp"
#include "util/newconfig/Error.hpp"
#include "util/newconfig/ObjectView.hpp"
#include "util/newconfig/Types.hpp"
#include "util/newconfig/ValueView.hpp"
#include <fmt/core.h>
@@ -38,6 +43,7 @@
#include <thread>
#include <utility>
#include <variant>
#include <vector>
namespace util::config {
/**
@@ -47,62 +53,76 @@ namespace util::config {
* without default values must be present in the user's config file.
*/
static ClioConfigDefinition ClioConfig = ClioConfigDefinition{
{{"database.type", ConfigValue{ConfigType::String}.defaultValue("cassandra")},
{{"database.type", ConfigValue{ConfigType::String}.defaultValue("cassandra").withConstraint(validateCassandraName)},
{"database.cassandra.contact_points", ConfigValue{ConfigType::String}.defaultValue("localhost")},
{"database.cassandra.port", ConfigValue{ConfigType::Integer}},
{"database.cassandra.port", ConfigValue{ConfigType::Integer}.withConstraint(validatePort)},
{"database.cassandra.keyspace", ConfigValue{ConfigType::String}.defaultValue("clio")},
{"database.cassandra.replication_factor", ConfigValue{ConfigType::Integer}.defaultValue(3u)},
{"database.cassandra.table_prefix", ConfigValue{ConfigType::String}.defaultValue("table_prefix")},
{"database.cassandra.max_write_requests_outstanding", ConfigValue{ConfigType::Integer}.defaultValue(10'000)},
{"database.cassandra.max_read_requests_outstanding", ConfigValue{ConfigType::Integer}.defaultValue(100'000)},
{"database.cassandra.max_write_requests_outstanding",
ConfigValue{ConfigType::Integer}.defaultValue(10'000).withConstraint(validateUint32)},
{"database.cassandra.max_read_requests_outstanding",
ConfigValue{ConfigType::Integer}.defaultValue(100'000).withConstraint(validateUint32)},
{"database.cassandra.threads",
ConfigValue{ConfigType::Integer}.defaultValue(static_cast<uint32_t>(std::thread::hardware_concurrency()))},
{"database.cassandra.core_connections_per_host", ConfigValue{ConfigType::Integer}.defaultValue(1)},
{"database.cassandra.queue_size_io", ConfigValue{ConfigType::Integer}.optional()},
{"database.cassandra.write_batch_size", ConfigValue{ConfigType::Integer}.defaultValue(20)},
{"etl_source.[].ip", Array{ConfigValue{ConfigType::String}.optional()}},
{"etl_source.[].ws_port", Array{ConfigValue{ConfigType::String}.optional().min(1).max(65535)}},
{"etl_source.[].grpc_port", Array{ConfigValue{ConfigType::String}.optional().min(1).max(65535)}},
{"forwarding.cache_timeout", ConfigValue{ConfigType::Double}.defaultValue(0.0)},
{"forwarding.request_timeout", ConfigValue{ConfigType::Double}.defaultValue(10.0)},
ConfigValue{ConfigType::Integer}
.defaultValue(static_cast<uint32_t>(std::thread::hardware_concurrency()))
.withConstraint(validateUint32)},
{"database.cassandra.core_connections_per_host",
ConfigValue{ConfigType::Integer}.defaultValue(1).withConstraint(validateUint16)},
{"database.cassandra.queue_size_io", ConfigValue{ConfigType::Integer}.optional().withConstraint(validateUint16)},
{"database.cassandra.write_batch_size",
ConfigValue{ConfigType::Integer}.defaultValue(20).withConstraint(validateUint16)},
{"etl_source.[].ip", Array{ConfigValue{ConfigType::String}.withConstraint(validateIP)}},
{"etl_source.[].ws_port", Array{ConfigValue{ConfigType::String}.withConstraint(validatePort)}},
{"etl_source.[].grpc_port", Array{ConfigValue{ConfigType::String}.withConstraint(validatePort)}},
{"forwarding.cache_timeout",
ConfigValue{ConfigType::Double}.defaultValue(0.0).withConstraint(validatePositiveDouble)},
{"forwarding.request_timeout",
ConfigValue{ConfigType::Double}.defaultValue(10.0).withConstraint(validatePositiveDouble)},
{"dos_guard.whitelist.[]", Array{ConfigValue{ConfigType::String}}},
{"dos_guard.max_fetches", ConfigValue{ConfigType::Integer}.defaultValue(1000'000)},
{"dos_guard.max_connections", ConfigValue{ConfigType::Integer}.defaultValue(20)},
{"dos_guard.max_requests", ConfigValue{ConfigType::Integer}.defaultValue(20)},
{"dos_guard.sweep_interval", ConfigValue{ConfigType::Double}.defaultValue(1.0)},
{"cache.peers.[].ip", Array{ConfigValue{ConfigType::String}}},
{"cache.peers.[].port", Array{ConfigValue{ConfigType::String}}},
{"server.ip", ConfigValue{ConfigType::String}},
{"server.port", ConfigValue{ConfigType::Integer}},
{"server.max_queue_size", ConfigValue{ConfigType::Integer}.defaultValue(0)},
{"dos_guard.max_fetches", ConfigValue{ConfigType::Integer}.defaultValue(1000'000).withConstraint(validateUint32)},
{"dos_guard.max_connections", ConfigValue{ConfigType::Integer}.defaultValue(20).withConstraint(validateUint32)},
{"dos_guard.max_requests", ConfigValue{ConfigType::Integer}.defaultValue(20).withConstraint(validateUint32)},
{"dos_guard.sweep_interval",
ConfigValue{ConfigType::Double}.defaultValue(1.0).withConstraint(validatePositiveDouble)},
{"cache.peers.[].ip", Array{ConfigValue{ConfigType::String}.withConstraint(validateIP)}},
{"cache.peers.[].port", Array{ConfigValue{ConfigType::String}.withConstraint(validatePort)}},
{"server.ip", ConfigValue{ConfigType::String}.withConstraint(validateIP)},
{"server.port", ConfigValue{ConfigType::Integer}.withConstraint(validatePort)},
{"server.workers", ConfigValue{ConfigType::Integer}.withConstraint(validateUint32)},
{"server.max_queue_size", ConfigValue{ConfigType::Integer}.defaultValue(0).withConstraint(validateUint32)},
{"server.local_admin", ConfigValue{ConfigType::Boolean}.optional()},
{"server.admin_password", ConfigValue{ConfigType::String}.optional()},
{"prometheus.enabled", ConfigValue{ConfigType::Boolean}.defaultValue(true)},
{"prometheus.compress_reply", ConfigValue{ConfigType::Boolean}.defaultValue(true)},
{"io_threads", ConfigValue{ConfigType::Integer}.defaultValue(2)},
{"cache.num_diffs", ConfigValue{ConfigType::Integer}.defaultValue(32)},
{"cache.num_markers", ConfigValue{ConfigType::Integer}.defaultValue(48)},
{"cache.num_cursors_from_diff", ConfigValue{ConfigType::Integer}.defaultValue(0)},
{"cache.num_cursors_from_account", ConfigValue{ConfigType::Integer}.defaultValue(0)},
{"cache.page_fetch_size", ConfigValue{ConfigType::Integer}.defaultValue(512)},
{"cache.load", ConfigValue{ConfigType::String}.defaultValue("async")},
{"log_channels.[].channel", Array{ConfigValue{ConfigType::String}.optional()}},
{"log_channels.[].log_level", Array{ConfigValue{ConfigType::String}.optional()}},
{"log_level", ConfigValue{ConfigType::String}.defaultValue("info")},
{"io_threads", ConfigValue{ConfigType::Integer}.defaultValue(2).withConstraint(validateUint16)},
{"cache.num_diffs", ConfigValue{ConfigType::Integer}.defaultValue(32).withConstraint(validateUint16)},
{"cache.num_markers", ConfigValue{ConfigType::Integer}.defaultValue(48).withConstraint(validateUint16)},
{"cache.num_cursors_from_diff", ConfigValue{ConfigType::Integer}.defaultValue(0).withConstraint(validateUint16)},
{"cache.num_cursors_from_account", ConfigValue{ConfigType::Integer}.defaultValue(0).withConstraint(validateUint16)
},
{"cache.page_fetch_size", ConfigValue{ConfigType::Integer}.defaultValue(512).withConstraint(validateUint16)},
{"cache.load", ConfigValue{ConfigType::String}.defaultValue("async").withConstraint(validateLoadMode)},
{"log_channels.[].channel", Array{ConfigValue{ConfigType::String}.optional().withConstraint(validateChannelName)}},
{"log_channels.[].log_level",
Array{ConfigValue{ConfigType::String}.optional().withConstraint(validateLogLevelName)}},
{"log_level", ConfigValue{ConfigType::String}.defaultValue("info").withConstraint(validateLogLevelName)},
{"log_format",
ConfigValue{ConfigType::String}.defaultValue(
R"(%TimeStamp% (%SourceLocation%) [%ThreadID%] %Channel%:%Severity% %Message%)"
)},
{"log_to_console", ConfigValue{ConfigType::Boolean}.defaultValue(false)},
{"log_directory", ConfigValue{ConfigType::String}.optional()},
{"log_rotation_size", ConfigValue{ConfigType::Integer}.defaultValue(2048)},
{"log_directory_max_size", ConfigValue{ConfigType::Integer}.defaultValue(50 * 1024)},
{"log_rotation_hour_interval", ConfigValue{ConfigType::Integer}.defaultValue(12)},
{"log_tag_style", ConfigValue{ConfigType::String}.defaultValue("uint")},
{"extractor_threads", ConfigValue{ConfigType::Integer}.defaultValue(2u)},
{"log_rotation_size", ConfigValue{ConfigType::Integer}.defaultValue(2048u).withConstraint(validateUint32)},
{"log_directory_max_size",
ConfigValue{ConfigType::Integer}.defaultValue(50u * 1024u).withConstraint(validateUint32)},
{"log_rotation_hour_interval", ConfigValue{ConfigType::Integer}.defaultValue(12).withConstraint(validateUint32)},
{"log_tag_style", ConfigValue{ConfigType::String}.defaultValue("uint").withConstraint(validateLogTag)},
{"extractor_threads", ConfigValue{ConfigType::Integer}.defaultValue(2u).withConstraint(validateUint32)},
{"read_only", ConfigValue{ConfigType::Boolean}.defaultValue(false)},
{"txn_threshold", ConfigValue{ConfigType::Integer}.defaultValue(0)},
{"start_sequence", ConfigValue{ConfigType::String}.optional()},
{"finish_sequence", ConfigValue{ConfigType::String}.optional()},
{"txn_threshold", ConfigValue{ConfigType::Integer}.defaultValue(0).withConstraint(validateUint16)},
{"start_sequence", ConfigValue{ConfigType::Integer}.optional().withConstraint(validateUint32)},
{"finish_sequence", ConfigValue{ConfigType::Integer}.optional().withConstraint(validateUint32)},
{"ssl_cert_file", ConfigValue{ConfigType::String}.optional()},
{"ssl_key_file", ConfigValue{ConfigType::String}.optional()},
{"api_version.min", ConfigValue{ConfigType::Integer}},
@@ -113,7 +133,7 @@ ClioConfigDefinition::ClioConfigDefinition(std::initializer_list<KeyValuePair> p
{
for (auto const& [key, value] : pair) {
if (key.contains("[]"))
ASSERT(std::holds_alternative<Array>(value), "Value must be array if key has \"[]\"");
ASSERT(std::holds_alternative<Array>(value), R"(Value must be array if key has "[]")");
map_.insert({key, value});
}
}
@@ -206,4 +226,51 @@ ClioConfigDefinition::arraySize(std::string_view prefix) const
std::unreachable();
}
std::optional<std::vector<Error>>
ClioConfigDefinition::parse(ConfigFileInterface const& config)
{
std::vector<Error> listOfErrors;
for (auto& [key, value] : map_) {
// if key doesn't exist in user config, makes sure it is marked as ".optional()" or has ".defaultValue()"" in
// ClioConfigDefitinion above
if (!config.containsKey(key)) {
if (std::holds_alternative<ConfigValue>(value)) {
if (!(std::get<ConfigValue>(value).isOptional() || std::get<ConfigValue>(value).hasValue()))
listOfErrors.emplace_back(key, "key is required in user Config");
} else if (std::holds_alternative<Array>(value)) {
if (!(std::get<Array>(value).getArrayPattern().isOptional()))
listOfErrors.emplace_back(key, "key is required in user Config");
}
continue;
}
ASSERT(
std::holds_alternative<ConfigValue>(value) || std::holds_alternative<Array>(value),
"Value must be of type ConfigValue or Array"
);
std::visit(
util::OverloadSet{// handle the case where the config value is a single element.
// attempt to set the value from the configuration for the specified key.
[&key, &config, &listOfErrors](ConfigValue& val) {
if (auto const maybeError = val.setValue(config.getValue(key), key);
maybeError.has_value())
listOfErrors.emplace_back(maybeError.value());
},
// handle the case where the config value is an array.
// iterate over each provided value in the array and attempt to set it for the key.
[&key, &config, &listOfErrors](Array& arr) {
for (auto const& val : config.getArray(key)) {
if (auto const maybeError = arr.addValue(val, key); maybeError.has_value())
listOfErrors.emplace_back(maybeError.value());
}
}
},
value
);
}
if (!listOfErrors.empty())
return listOfErrors;
return std::nullopt;
}
} // namespace util::config

View File

@@ -24,7 +24,7 @@
#include "util/newconfig/ConfigDescription.hpp"
#include "util/newconfig/ConfigFileInterface.hpp"
#include "util/newconfig/ConfigValue.hpp"
#include "util/newconfig/Errors.hpp"
#include "util/newconfig/Error.hpp"
#include "util/newconfig/ObjectView.hpp"
#include "util/newconfig/ValueView.hpp"
@@ -41,6 +41,7 @@
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>
namespace util::config {
@@ -66,12 +67,13 @@ public:
/**
* @brief Parses the configuration file
*
* Should also check that no extra configuration key/value pairs are present
* Also checks that no extra configuration key/value pairs are present. Adds to list of Errors
* if it does
*
* @param config The configuration file interface
* @return An optional Error object if parsing fails
* @return An optional vector of Error objects stating all the failures if parsing fails
*/
[[nodiscard]] std::optional<Error>
[[nodiscard]] std::optional<std::vector<Error>>
parse(ConfigFileInterface const& config);
/**
@@ -80,9 +82,9 @@ public:
* Should only check for valid values, without populating
*
* @param config The configuration file interface
* @return An optional Error object if validation fails
* @return An optional vector of Error objects stating all the failures if validation fails
*/
[[nodiscard]] std::optional<Error>
[[nodiscard]] std::optional<std::vector<Error>>
validate(ConfigFileInterface const& config) const;
/**

View File

@@ -90,7 +90,9 @@ private:
KV{"server.ip", "IP address of the Clio HTTP server."},
KV{"server.port", "Port number of the Clio HTTP server."},
KV{"server.max_queue_size", "Maximum size of the server's request queue."},
KV{"server.workers", "Maximum number of threads for server to run with."},
KV{"server.local_admin", "Indicates if the server should run with admin privileges."},
KV{"server.admin_password", "Password for Clio admin-only APIs."},
KV{"prometheus.enabled", "Enable or disable Prometheus metrics."},
KV{"prometheus.compress_reply", "Enable or disable compression of Prometheus responses."},
KV{"io_threads", "Number of I/O threads."},

View File

@@ -19,9 +19,8 @@
#pragma once
#include "util/newconfig/ConfigValue.hpp"
#include "util/newconfig/Types.hpp"
#include <optional>
#include <string_view>
#include <vector>
@@ -36,31 +35,33 @@ namespace util::config {
class ConfigFileInterface {
public:
virtual ~ConfigFileInterface() = default;
/**
* @brief Parses the provided path of user clio configuration data
*
* @param filePath The path to the Clio Config data
*/
virtual void
parse(std::string_view filePath) = 0;
/**
* @brief Retrieves a configuration value.
* @brief Retrieves the value of configValue.
*
* @param key The key of the configuration value.
* @return An optional containing the configuration value if found, otherwise std::nullopt.
* @param key The key of configuration.
* @return the value assosiated with key.
*/
virtual std::optional<ConfigValue>
virtual Value
getValue(std::string_view key) const = 0;
/**
* @brief Retrieves an array of configuration values.
*
* @param key The key of the configuration array.
* @return An optional containing a vector of configuration values if found, otherwise std::nullopt.
* @return A vector of configuration values if found, otherwise std::nullopt.
*/
virtual std::optional<std::vector<ConfigValue>>
virtual std::vector<Value>
getArray(std::string_view key) const = 0;
/**
* @brief Checks if key exist in configuration file.
*
* @param key The key to search for.
* @return true if key exists in configuration file, false otherwise.
*/
virtual bool
containsKey(std::string_view key) const = 0;
};
} // namespace util::config

View File

@@ -0,0 +1,166 @@
//------------------------------------------------------------------------------
/*
This file is part of clio: https://github.com/XRPLF/clio
Copyright (c) 2024, the clio developers.
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#include "util/newconfig/ConfigFileJson.hpp"
#include "util/Assert.hpp"
#include "util/newconfig/Error.hpp"
#include "util/newconfig/Types.hpp"
#include <boost/filesystem/path.hpp>
#include <boost/json/array.hpp>
#include <boost/json/object.hpp>
#include <boost/json/parse.hpp>
#include <boost/json/parse_options.hpp>
#include <boost/json/value.hpp>
#include <fmt/core.h>
#include <cstddef>
#include <exception>
#include <fstream>
#include <ios>
#include <iostream>
#include <ostream>
#include <sstream>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
namespace util::config {
namespace {
/**
* @brief Extracts the value from a JSON object and converts it into the corresponding type.
*
* @param jsonValue The JSON value to extract.
* @return A variant containing the same type corresponding to the extracted value.
*/
[[nodiscard]] Value
extractJsonValue(boost::json::value const& jsonValue)
{
if (jsonValue.is_int64()) {
return jsonValue.as_int64();
}
if (jsonValue.is_string()) {
return jsonValue.as_string().c_str();
}
if (jsonValue.is_bool()) {
return jsonValue.as_bool();
}
if (jsonValue.is_double()) {
return jsonValue.as_double();
}
ASSERT(false, "Json is not of type int, string, bool or double");
std::unreachable();
}
} // namespace
ConfigFileJson::ConfigFileJson(boost::json::object jsonObj)
{
flattenJson(jsonObj, "");
}
std::expected<ConfigFileJson, Error>
ConfigFileJson::make_ConfigFileJson(boost::filesystem::path configFilePath)
{
try {
if (auto const in = std::ifstream(configFilePath.string(), std::ios::in | std::ios::binary); in) {
std::stringstream contents;
contents << in.rdbuf();
auto opts = boost::json::parse_options{};
opts.allow_comments = true;
auto const tempObj = boost::json::parse(contents.str(), {}, opts).as_object();
return ConfigFileJson{tempObj};
}
return std::unexpected<Error>(
Error{fmt::format("Could not open configuration file '{}'", configFilePath.string())}
);
} catch (std::exception const& e) {
return std::unexpected<Error>(Error{fmt::format(
"An error occurred while processing configuration file '{}': {}", configFilePath.string(), e.what()
)});
}
}
Value
ConfigFileJson::getValue(std::string_view key) const
{
auto const jsonValue = jsonObject_.at(key);
auto const value = extractJsonValue(jsonValue);
return value;
}
std::vector<Value>
ConfigFileJson::getArray(std::string_view key) const
{
ASSERT(jsonObject_.at(key).is_array(), "Key {} has value that is not an array", key);
std::vector<Value> configValues;
auto const arr = jsonObject_.at(key).as_array();
for (auto const& item : arr) {
auto const value = extractJsonValue(item);
configValues.emplace_back(value);
}
return configValues;
}
bool
ConfigFileJson::containsKey(std::string_view key) const
{
return jsonObject_.contains(key);
}
void
ConfigFileJson::flattenJson(boost::json::object const& obj, std::string const& prefix)
{
for (auto const& [key, value] : obj) {
std::string const fullKey = prefix.empty() ? std::string(key) : fmt::format("{}.{}", prefix, std::string(key));
// In ClioConfigDefinition, value must be a primitive or array
if (value.is_object()) {
flattenJson(value.as_object(), fullKey);
} else if (value.is_array()) {
auto const& arr = value.as_array();
for (std::size_t i = 0; i < arr.size(); ++i) {
std::string const arrayPrefix = fullKey + ".[]";
if (arr[i].is_object()) {
flattenJson(arr[i].as_object(), arrayPrefix);
} else {
jsonObject_[arrayPrefix] = arr;
}
}
} else {
// if "[]" is present in key, then value must be an array instead of primitive
if (fullKey.contains(".[]") && !jsonObject_.contains(fullKey)) {
boost::json::array newArray;
newArray.emplace_back(value);
jsonObject_[fullKey] = newArray;
} else if (fullKey.contains(".[]") && jsonObject_.contains(fullKey)) {
jsonObject_[fullKey].as_array().emplace_back(value);
} else {
jsonObject_[fullKey] = value;
}
}
}
}
} // namespace util::config

View File

@@ -0,0 +1,100 @@
//------------------------------------------------------------------------------
/*
This file is part of clio: https://github.com/XRPLF/clio
Copyright (c) 2024, the clio developers.
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#pragma once
#include "util/newconfig/ConfigFileInterface.hpp"
#include "util/newconfig/Error.hpp"
#include "util/newconfig/Types.hpp"
#include <boost/filesystem/path.hpp>
#include <boost/json/object.hpp>
#include <expected>
#include <string>
#include <string_view>
#include <vector>
namespace util::config {
/** @brief Json representation of config */
class ConfigFileJson final : public ConfigFileInterface {
public:
/**
* @brief Construct a new ConfigJson object and stores the values from
* user's config into a json object.
*
* @param jsonObj the Json object to parse; represents user's config
*/
ConfigFileJson(boost::json::object jsonObj);
/**
* @brief Retrieves a configuration value by its key.
*
* @param key The key of the configuration value to retrieve.
* @return A variant containing the same type corresponding to the extracted value.
*/
[[nodiscard]] Value
getValue(std::string_view key) const override;
/**
* @brief Retrieves an array of configuration values by its key.
*
* @param key The key of the configuration array to retrieve.
* @return A vector of variants holding the config values specified by user.
*/
[[nodiscard]] std::vector<Value>
getArray(std::string_view key) const override;
/**
* @brief Checks if the configuration contains a specific key.
*
* @param key The key to check for.
* @return True if the key exists, false otherwise.
*/
[[nodiscard]] bool
containsKey(std::string_view key) const override;
/**
* @brief Creates a new ConfigFileJson by parsing the provided JSON file and
* stores the values in the object.
*
* @param configFilePath The path to the JSON file to be parsed.
* @return A ConfigFileJson object if parsing user file is successful. Error otherwise
*/
[[nodiscard]] static std::expected<ConfigFileJson, Error>
make_ConfigFileJson(boost::filesystem::path configFilePath);
private:
/**
* @brief Recursive function to flatten a JSON object into the same structure as the Clio Config.
*
* The keys will end up having the same naming convensions in Clio Config.
* Other than the keys specified in user Config file, no new keys are created.
*
* @param obj The JSON object to flatten.
* @param prefix The prefix to use for the keys in the flattened object.
*/
void
flattenJson(boost::json::object const& obj, std::string const& prefix);
boost::json::object jsonObject_;
};
} // namespace util::config

View File

@@ -19,15 +19,31 @@
#pragma once
#include <string>
#include "util/newconfig/ConfigFileInterface.hpp"
#include "util/newconfig/Types.hpp"
#include <boost/filesystem/path.hpp>
#include <string_view>
#include <vector>
// TODO: implement when we support yaml
namespace util::config {
/** @brief todo: Will display the different errors when parsing config */
struct Error {
std::string_view key;
std::string_view error;
/** @brief Yaml representation of config */
class ConfigFileYaml final : public ConfigFileInterface {
public:
ConfigFileYaml() = default;
Value
getValue(std::string_view key) const override;
std::vector<Value>
getArray(std::string_view key) const override;
bool
containsKey(std::string_view key) const override;
};
} // namespace util::config

View File

@@ -20,43 +20,23 @@
#pragma once
#include "util/Assert.hpp"
#include "util/UnsupportedType.hpp"
#include "util/OverloadSet.hpp"
#include "util/newconfig/ConfigConstraints.hpp"
#include "util/newconfig/Error.hpp"
#include "util/newconfig/Types.hpp"
#include <fmt/core.h>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <optional>
#include <string>
#include <type_traits>
#include <string_view>
#include <variant>
namespace util::config {
/** @brief Custom clio config types */
enum class ConfigType { Integer, String, Double, Boolean };
/**
* @brief Get the corresponding clio config type
*
* @tparam Type The type to get the corresponding ConfigType for
* @return The corresponding ConfigType
*/
template <typename Type>
constexpr ConfigType
getType()
{
if constexpr (std::is_same_v<Type, int64_t>) {
return ConfigType::Integer;
} else if constexpr (std::is_same_v<Type, std::string>) {
return ConfigType::String;
} else if constexpr (std::is_same_v<Type, double>) {
return ConfigType::Double;
} else if constexpr (std::is_same_v<Type, bool>) {
return ConfigType::Boolean;
} else {
static_assert(util::Unsupported<Type>, "Wrong config type");
}
}
/**
* @brief Represents the config values for Json/Yaml config
*
@@ -65,8 +45,6 @@ getType()
*/
class ConfigValue {
public:
using Type = std::variant<int64_t, std::string, bool, double>;
/**
* @brief Constructor initializing with the config type
*
@@ -83,12 +61,92 @@ public:
* @return Reference to this ConfigValue
*/
[[nodiscard]] ConfigValue&
defaultValue(Type value)
defaultValue(Value value)
{
setValue(value);
auto const err = checkTypeConsistency(type_, value);
ASSERT(!err.has_value(), "{}", err->error);
value_ = value;
return *this;
}
/**
* @brief Sets the value current ConfigValue given by the User's defined value
*
* @param value The value to set
* @param key The Config key associated with the value. Optional to include; Used for debugging message to user.
* @return optional Error if user tries to set a value of wrong type or not within a constraint
*/
[[nodiscard]] std::optional<Error>
setValue(Value value, std::optional<std::string_view> key = std::nullopt)
{
auto err = checkTypeConsistency(type_, value);
if (err.has_value()) {
if (key.has_value())
err->error = fmt::format("{} {}", key.value(), err->error);
return err;
}
if (cons_.has_value()) {
auto constraintCheck = cons_->get().checkConstraint(value);
if (constraintCheck.has_value()) {
if (key.has_value())
constraintCheck->error = fmt::format("{} {}", key.value(), constraintCheck->error);
return constraintCheck;
}
}
value_ = value;
return std::nullopt;
}
/**
* @brief Assigns a constraint to the ConfigValue.
*
* This method associates a specific constraint with the ConfigValue.
* If the ConfigValue already holds a value, the method will check whether
* the value satisfies the given constraint. If the constraint is not satisfied,
* an assertion failure will occur with a detailed error message.
*
* @param cons The constraint to be applied to the ConfigValue.
* @return A reference to the modified ConfigValue object.
*/
[[nodiscard]] constexpr ConfigValue&
withConstraint(Constraint const& cons)
{
cons_ = std::reference_wrapper<Constraint const>(cons);
ASSERT(cons_.has_value(), "Constraint must be defined");
if (value_.has_value()) {
auto const& temp = cons_.value().get();
auto const& result = temp.checkConstraint(value_.value());
if (result.has_value()) {
// useful for specifying clear Error message
std::string type;
std::visit(
util::OverloadSet{
[&type](bool tmp) { type = fmt::format("bool {}", tmp); },
[&type](std::string const& tmp) { type = fmt::format("string {}", tmp); },
[&type](double tmp) { type = fmt::format("double {}", tmp); },
[&type](int64_t tmp) { type = fmt::format("int {}", tmp); }
},
value_.value()
);
ASSERT(false, "Value {} ConfigValue does not satisfy the set Constraint", type);
}
}
return *this;
}
/**
* @brief Retrieves the constraint associated with this ConfigValue, if any.
*
* @return An optional reference to the associated Constraint.
*/
[[nodiscard]] std::optional<std::reference_wrapper<Constraint const>>
getConstraint() const
{
return cons_;
}
/**
* @brief Gets the config type
*
@@ -100,32 +158,6 @@ public:
return type_;
}
/**
* @brief Sets the minimum value for the config
*
* @param min The minimum value
* @return Reference to this ConfigValue
*/
[[nodiscard]] constexpr ConfigValue&
min(std::uint32_t min)
{
min_ = min;
return *this;
}
/**
* @brief Sets the maximum value for the config
*
* @param max The maximum value
* @return Reference to this ConfigValue
*/
[[nodiscard]] constexpr ConfigValue&
max(std::uint32_t max)
{
max_ = max;
return *this;
}
/**
* @brief Sets the config value as optional, meaning the user doesn't have to provide the value in their config
*
@@ -165,7 +197,7 @@ public:
*
* @return Config Value
*/
[[nodiscard]] Type const&
[[nodiscard]] Value const&
getValue() const
{
return value_.value();
@@ -178,39 +210,28 @@ private:
* @param type The config type
* @param value The config value
*/
static void
checkTypeConsistency(ConfigType type, Type value)
static std::optional<Error>
checkTypeConsistency(ConfigType type, Value value)
{
if (std::holds_alternative<std::string>(value)) {
ASSERT(type == ConfigType::String, "Value does not match type string");
} else if (std::holds_alternative<bool>(value)) {
ASSERT(type == ConfigType::Boolean, "Value does not match type boolean");
} else if (std::holds_alternative<double>(value)) {
ASSERT(type == ConfigType::Double, "Value does not match type double");
} else if (std::holds_alternative<int64_t>(value)) {
ASSERT(type == ConfigType::Integer, "Value does not match type integer");
if (type == ConfigType::String && !std::holds_alternative<std::string>(value)) {
return Error{"value does not match type string"};
}
if (type == ConfigType::Boolean && !std::holds_alternative<bool>(value)) {
return Error{"value does not match type boolean"};
}
/**
* @brief Sets the value for the config
*
* @param value The value to set
* @return The value that was set
*/
Type
setValue(Type value)
{
checkTypeConsistency(type_, value);
value_ = value;
return value;
if (type == ConfigType::Double && !std::holds_alternative<double>(value)) {
return Error{"value does not match type double"};
}
if (type == ConfigType::Integer && !std::holds_alternative<int64_t>(value)) {
return Error{"value does not match type integer"};
}
return std::nullopt;
}
ConfigType type_{};
bool optional_{false};
std::optional<Type> value_;
std::optional<std::uint32_t> min_;
std::optional<std::uint32_t> max_;
std::optional<Value> value_;
std::optional<std::reference_wrapper<Constraint const>> cons_;
};
} // namespace util::config

View File

@@ -0,0 +1,57 @@
//------------------------------------------------------------------------------
/*
This file is part of clio: https://github.com/XRPLF/clio
Copyright (c) 2024, the clio developers.
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#pragma once
#include <fmt/core.h>
#include <string>
#include <string_view>
#include <utility>
namespace util::config {
/** @brief Displays the different errors when parsing user config */
struct Error {
/**
* @brief Constructs an Error with a custom error message.
*
* @param err the error message to display to users.
*/
Error(std::string err) : error{std::move(err)}
{
}
/**
* @brief Constructs an Error with a custom error message.
*
* @param key the key associated with the error.
* @param err the error message to display to users.
*/
Error(std::string_view key, std::string_view err)
: error{
fmt::format("{} {}", key, err),
}
{
}
std::string error;
};
} // namespace util::config

View File

@@ -0,0 +1,60 @@
//------------------------------------------------------------------------------
/*
This file is part of clio: https://github.com/XRPLF/clio
Copyright (c) 2024, the clio developers.
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#pragma once
#include "util/UnsupportedType.hpp"
#include <cstdint>
#include <string>
#include <type_traits>
#include <variant>
namespace util::config {
/** @brief Custom clio config types */
enum class ConfigType { Integer, String, Double, Boolean };
/** @brief Represents the supported Config Values */
using Value = std::variant<int64_t, std::string, bool, double>;
/**
* @brief Get the corresponding clio config type
*
* @tparam Type The type to get the corresponding ConfigType for
* @return The corresponding ConfigType
*/
template <typename Type>
constexpr ConfigType
getType()
{
if constexpr (std::is_same_v<Type, int64_t>) {
return ConfigType::Integer;
} else if constexpr (std::is_same_v<Type, std::string>) {
return ConfigType::String;
} else if constexpr (std::is_same_v<Type, double>) {
return ConfigType::Double;
} else if constexpr (std::is_same_v<Type, bool>) {
return ConfigType::Boolean;
} else {
static_assert(util::Unsupported<Type>, "Wrong config type");
}
}
} // namespace util::config

View File

@@ -21,6 +21,7 @@
#include "util/Assert.hpp"
#include "util/newconfig/ConfigValue.hpp"
#include "util/newconfig/Types.hpp"
#include <cstdint>
#include <string>
@@ -55,9 +56,9 @@ double
ValueView::asDouble() const
{
if (configVal_.get().hasValue()) {
if (type() == ConfigType::Double) {
if (type() == ConfigType::Double)
return std::get<double>(configVal_.get().getValue());
}
if (type() == ConfigType::Integer)
return static_cast<double>(std::get<int64_t>(configVal_.get().getValue()));
}

View File

@@ -21,6 +21,7 @@
#include "util/Assert.hpp"
#include "util/newconfig/ConfigValue.hpp"
#include "util/newconfig/Types.hpp"
#include <fmt/core.h>
@@ -84,7 +85,7 @@ public:
return static_cast<T>(val);
}
}
ASSERT(false, "Value view is not of any Int type");
ASSERT(false, "Value view is not of Int type");
return 0;
}

View File

@@ -39,8 +39,10 @@
#include <boost/beast/websocket/stream_base.hpp>
#include <boost/system/errc.hpp>
#include <atomic>
#include <chrono>
#include <expected>
#include <memory>
#include <optional>
#include <string>
#include <utility>
@@ -123,15 +125,20 @@ private:
static void
withTimeout(Operation&& operation, boost::asio::yield_context yield, std::chrono::steady_clock::duration timeout)
{
auto isCompleted = std::make_shared<bool>(false);
boost::asio::cancellation_signal cancellationSignal;
auto cyield = boost::asio::bind_cancellation_slot(cancellationSignal.slot(), yield);
boost::asio::steady_timer timer{boost::asio::get_associated_executor(cyield), timeout};
timer.async_wait([&cancellationSignal](boost::system::error_code errorCode) {
if (!errorCode)
// The timer below can be called with no error code even if the operation is completed before the timeout, so we
// need an additional flag here
timer.async_wait([&cancellationSignal, isCompleted](boost::system::error_code errorCode) {
if (!errorCode and not *isCompleted)
cancellationSignal.emit(boost::asio::cancellation_type::terminal);
});
operation(cyield);
*isCompleted = true;
}
static boost::system::error_code

View File

@@ -43,9 +43,11 @@
#include <algorithm>
#include <expected>
#include <functional>
#include <iterator>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
@@ -75,6 +77,14 @@ TestWsConnection::send(std::string const& message, boost::asio::yield_context yi
return std::nullopt;
}
void
TestWsConnection::sendPing(boost::beast::websocket::ping_data const& data, boost::asio::yield_context yield)
{
boost::beast::error_code errorCode;
ws_.async_ping(data, yield[errorCode]);
[&]() { ASSERT_FALSE(errorCode) << errorCode.message(); }();
}
std::optional<std::string>
TestWsConnection::receive(boost::asio::yield_context yield)
{
@@ -105,6 +115,20 @@ TestWsConnection::headers() const
return headers_;
}
void
TestWsConnection::setControlFrameCallback(
std::function<void(boost::beast::websocket::frame_type, std::string_view)> callback
)
{
ws_.control_callback(std::move(callback));
}
void
TestWsConnection::resetControlFrameCallback()
{
ws_.control_callback();
}
TestWsServer::TestWsServer(asio::io_context& context, std::string const& host) : acceptor_(context)
{
auto endpoint = asio::ip::tcp::endpoint(boost::asio::ip::make_address(host), 0);

View File

@@ -25,6 +25,7 @@
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/spawn.hpp>
#include <boost/beast/core/tcp_stream.hpp>
#include <boost/beast/websocket/rfc6455.hpp>
#include <boost/beast/websocket/stream.hpp>
#include <expected>
@@ -54,6 +55,9 @@ public:
std::optional<std::string>
send(std::string const& message, boost::asio::yield_context yield);
void
sendPing(boost::beast::websocket::ping_data const& data, boost::asio::yield_context yield);
// returns nullopt if the connection is closed
std::optional<std::string>
receive(boost::asio::yield_context yield);
@@ -63,6 +67,12 @@ public:
std::vector<util::requests::HttpHeader> const&
headers() const;
void
setControlFrameCallback(std::function<void(boost::beast::websocket::frame_type, std::string_view)> callback);
void
resetControlFrameCallback();
};
using TestWsConnectionPtr = std::unique_ptr<TestWsConnection>;

View File

@@ -20,13 +20,25 @@
#pragma once
#include "util/newconfig/Array.hpp"
#include "util/newconfig/ConfigConstraints.hpp"
#include "util/newconfig/ConfigDefinition.hpp"
#include "util/newconfig/ConfigValue.hpp"
#include "util/newconfig/Types.hpp"
#include <gtest/gtest.h>
using namespace util::config;
/**
* @brief A mock ClioConfigDefinition for testing purposes.
*
* In the actual Clio configuration, arrays typically hold optional values, meaning users are not required to
* provide values for them.
*
* For primitive types (i.e., single specific values), some are mandatory and must be explicitly defined in the
* user's configuration file, including both the key and the corresponding value, while some are optional
*/
inline ClioConfigDefinition
generateConfig()
{
@@ -36,36 +48,115 @@ generateConfig()
{"header.admin", ConfigValue{ConfigType::Boolean}.defaultValue(true)},
{"header.sub.sub2Value", ConfigValue{ConfigType::String}.defaultValue("TSM")},
{"ip", ConfigValue{ConfigType::Double}.defaultValue(444.22)},
{"array.[].sub",
Array{
ConfigValue{ConfigType::Double}.defaultValue(111.11), ConfigValue{ConfigType::Double}.defaultValue(4321.55)
}},
{"array.[].sub2",
Array{
ConfigValue{ConfigType::String}.defaultValue("subCategory"),
ConfigValue{ConfigType::String}.defaultValue("temporary")
}},
{"higher.[].low.section", Array{ConfigValue{ConfigType::String}.defaultValue("true")}},
{"higher.[].low.admin", Array{ConfigValue{ConfigType::Boolean}.defaultValue(false)}},
{"dosguard.whitelist.[]",
Array{
ConfigValue{ConfigType::String}.defaultValue("125.5.5.2"),
ConfigValue{ConfigType::String}.defaultValue("204.2.2.2")
}},
{"dosguard.port", ConfigValue{ConfigType::Integer}.defaultValue(55555)}
{"array.[].sub", Array{ConfigValue{ConfigType::Double}}},
{"array.[].sub2", Array{ConfigValue{ConfigType::String}.optional()}},
{"higher.[].low.section", Array{ConfigValue{ConfigType::String}.withConstraint(validateChannelName)}},
{"higher.[].low.admin", Array{ConfigValue{ConfigType::Boolean}}},
{"dosguard.whitelist.[]", Array{ConfigValue{ConfigType::String}.optional()}},
{"dosguard.port", ConfigValue{ConfigType::Integer}.defaultValue(55555).withConstraint(validatePort)},
{"optional.withDefault", ConfigValue{ConfigType::Double}.defaultValue(0.0).optional()},
{"optional.withNoDefault", ConfigValue{ConfigType::Double}.optional()},
{"requireValue", ConfigValue{ConfigType::String}}
};
}
/* The config definition above would look like this structure in config.json:
"header": {
/* The config definition above would look like this structure in config.json
{
"header": {
"text1": "value",
"port": 123,
"port": 321,
"admin": true,
"sub": {
"sub2Value": "TSM"
}
},
"ip": 444.22,
"array": [
{
"sub": //optional for user to include
"sub2": //optional for user to include
},
],
"higher": [
{
"low": {
"section": //optional for user to include
"admin": //optional for user to include
}
}
],
"dosguard": {
"whitelist": [
// mandatory for user to include
],
"port" : 55555
},
},
"optional" : {
"withDefault" : 0.0,
"withNoDefault" : //optional for user to include
},
"requireValue" : // value must be provided by user
}
*/
/* Used to test overwriting default values in ClioConfigDefinition Above */
constexpr static auto JSONData = R"JSON(
{
"header": {
"text1": "value",
"port": 321,
"admin": false,
"sub": {
"sub2Value": "TSM"
}
},
"array": [
{
"sub": 111.11,
"sub2": "subCategory"
},
{
"sub": 4321.55,
"sub2": "temporary"
},
{
"sub": 5555.44,
"sub2": "london"
}
],
"higher": [
{
"low": {
"section": "WebServer",
"admin": false
}
}
],
"dosguard": {
"whitelist": [
"125.5.5.1", "204.2.2.1"
],
"port" : 44444
},
"optional" : {
"withDefault" : 0.0
},
"requireValue" : "required"
}
)JSON";
/* After parsing jsonValue and populating it into ClioConfig, It will look like this below in json format;
{
"header": {
"text1": "value",
"port": 321,
"admin": false,
"sub": {
"sub2Value": "TSM"
}
},
"ip": 444.22,
"array": [
{
"sub": 111.11,
@@ -74,22 +165,50 @@ generateConfig()
{
"sub": 4321.55,
"sub2": "temporary"
},
{
"sub": 5555.44,
"sub2": "london"
}
],
"higher": [
{
"low": {
"section": "true",
"section": "WebServer",
"admin": false
}
}
],
"dosguard": {
"whitelist": [
"125.5.5.2", "204.2.2.2"
"125.5.5.1", "204.2.2.1"
],
"port" : 55555
"port" : 44444
}
},
"optional" : {
"withDefault" : 0.0
},
"requireValue" : "required"
}
*/
// Invalid Json key/values
constexpr static auto invalidJSONData = R"JSON(
{
"header": {
"port": "999",
"admin": "true"
},
"dosguard": {
"whitelist": [
false
]
},
"idk": true,
"requireValue" : "required",
"optional" : {
"withDefault" : "0.0"
}
}
)JSON";

View File

@@ -133,12 +133,13 @@ target_sources(
web/RPCServerHandlerTests.cpp
web/ServerTests.cpp
# New Config
util/newconfig/ArrayViewTests.cpp
util/newconfig/ObjectViewTests.cpp
util/newconfig/ValueViewTests.cpp
util/newconfig/ArrayTests.cpp
util/newconfig/ConfigValueTests.cpp
util/newconfig/ArrayViewTests.cpp
util/newconfig/ClioConfigDefinitionTests.cpp
util/newconfig/ConfigValueTests.cpp
util/newconfig/ObjectViewTests.cpp
util/newconfig/JsonConfigFileTests.cpp
util/newconfig/ValueViewTests.cpp
)
configure_file(test_data/cert.pem ${CMAKE_BINARY_DIR}/tests/unit/test_data/cert.pem COPYONLY)

View File

@@ -49,7 +49,7 @@ struct AmendmentCenterTest : util::prometheus::WithPrometheus, MockBackendTest,
TEST_F(AmendmentCenterTest, AllAmendmentsFromLibXRPLAreSupported)
{
for (auto const& [name, _] : ripple::allAmendments()) {
ASSERT_TRUE(amendmentCenter.isSupported(name)) << "XRPL amendment not supported by Clio: " << name;
EXPECT_TRUE(amendmentCenter.isSupported(name)) << "XRPL amendment not supported by Clio: " << name;
}
ASSERT_EQ(amendmentCenter.getSupported().size(), ripple::allAmendments().size());

View File

@@ -280,15 +280,12 @@ TEST_F(LoadBalancerOnDisconnectHookTests, source0Disconnects)
EXPECT_CALL(sourceFactory_.sourceAt(0), setForwarding(false));
EXPECT_CALL(sourceFactory_.sourceAt(1), isConnected()).WillOnce(Return(true));
EXPECT_CALL(sourceFactory_.sourceAt(1), setForwarding(true));
sourceFactory_.callbacksAt(0).onDisconnect();
sourceFactory_.callbacksAt(0).onDisconnect(true);
}
TEST_F(LoadBalancerOnDisconnectHookTests, source1Disconnects)
{
EXPECT_CALL(sourceFactory_.sourceAt(0), isConnected()).WillOnce(Return(true));
EXPECT_CALL(sourceFactory_.sourceAt(0), setForwarding(true));
EXPECT_CALL(sourceFactory_.sourceAt(1), setForwarding(false));
sourceFactory_.callbacksAt(1).onDisconnect();
sourceFactory_.callbacksAt(1).onDisconnect(false);
}
TEST_F(LoadBalancerOnDisconnectHookTests, source0DisconnectsAndConnectsBack)
@@ -297,29 +294,25 @@ TEST_F(LoadBalancerOnDisconnectHookTests, source0DisconnectsAndConnectsBack)
EXPECT_CALL(sourceFactory_.sourceAt(0), setForwarding(false));
EXPECT_CALL(sourceFactory_.sourceAt(1), isConnected()).WillOnce(Return(true));
EXPECT_CALL(sourceFactory_.sourceAt(1), setForwarding(true));
sourceFactory_.callbacksAt(0).onDisconnect();
sourceFactory_.callbacksAt(0).onDisconnect(true);
sourceFactory_.callbacksAt(0).onConnect();
}
TEST_F(LoadBalancerOnDisconnectHookTests, source1DisconnectsAndConnectsBack)
{
EXPECT_CALL(sourceFactory_.sourceAt(0), isConnected()).WillOnce(Return(true));
EXPECT_CALL(sourceFactory_.sourceAt(0), setForwarding(true));
EXPECT_CALL(sourceFactory_.sourceAt(1), setForwarding(false));
sourceFactory_.callbacksAt(1).onDisconnect();
sourceFactory_.callbacksAt(1).onDisconnect(false);
sourceFactory_.callbacksAt(1).onConnect();
}
TEST_F(LoadBalancerOnConnectHookTests, bothSourcesDisconnectAndConnectBack)
{
EXPECT_CALL(sourceFactory_.sourceAt(0), isConnected()).Times(2).WillRepeatedly(Return(false));
EXPECT_CALL(sourceFactory_.sourceAt(0), setForwarding(false)).Times(2);
EXPECT_CALL(sourceFactory_.sourceAt(1), isConnected()).Times(2).WillRepeatedly(Return(false));
EXPECT_CALL(sourceFactory_.sourceAt(1), setForwarding(false)).Times(2);
sourceFactory_.callbacksAt(0).onDisconnect();
sourceFactory_.callbacksAt(1).onDisconnect();
EXPECT_CALL(sourceFactory_.sourceAt(0), isConnected()).WillOnce(Return(false));
EXPECT_CALL(sourceFactory_.sourceAt(0), setForwarding(false));
EXPECT_CALL(sourceFactory_.sourceAt(1), isConnected()).WillOnce(Return(false));
EXPECT_CALL(sourceFactory_.sourceAt(1), setForwarding(false));
sourceFactory_.callbacksAt(0).onDisconnect(true);
sourceFactory_.callbacksAt(1).onDisconnect(false);
EXPECT_CALL(sourceFactory_.sourceAt(0), isConnected()).WillOnce(Return(true));
EXPECT_CALL(sourceFactory_.sourceAt(0), setForwarding(true));
@@ -362,12 +355,7 @@ TEST_F(LoadBalancer3SourcesTests, forwardingUpdate)
sourceFactory_.callbacksAt(1).onConnect();
// Source 0 got disconnected
EXPECT_CALL(sourceFactory_.sourceAt(0), isConnected()).WillOnce(Return(false));
EXPECT_CALL(sourceFactory_.sourceAt(0), setForwarding(false));
EXPECT_CALL(sourceFactory_.sourceAt(1), isConnected()).WillOnce(Return(true));
EXPECT_CALL(sourceFactory_.sourceAt(1), setForwarding(true));
EXPECT_CALL(sourceFactory_.sourceAt(2), setForwarding(false)); // only source 1 must be forwarding
sourceFactory_.callbacksAt(0).onDisconnect();
sourceFactory_.callbacksAt(0).onDisconnect(false);
}
struct LoadBalancerLoadInitialLedgerTests : LoadBalancerOnConnectHookTests {

View File

@@ -20,27 +20,33 @@
#include "etl/impl/SubscriptionSource.hpp"
#include "util/LoggerFixtures.hpp"
#include "util/MockNetworkValidatedLedgers.hpp"
#include "util/MockPrometheus.hpp"
#include "util/MockSubscriptionManager.hpp"
#include "util/TestWsServer.hpp"
#include "util/prometheus/Gauge.hpp"
#include <boost/asio/io_context.hpp>
#include <boost/asio/spawn.hpp>
#include <boost/json/object.hpp>
#include <boost/json/serialize.hpp>
#include <fmt/core.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <chrono>
#include <cstdint>
#include <cstdlib>
#include <optional>
#include <string>
#include <thread>
#include <utility>
using namespace etl::impl;
using testing::MockFunction;
using testing::StrictMock;
struct SubscriptionSourceConnectionTests : public NoLoggerFixture {
SubscriptionSourceConnectionTests()
struct SubscriptionSourceConnectionTestsBase : public NoLoggerFixture {
SubscriptionSourceConnectionTestsBase()
{
subscriptionSource_.run();
}
@@ -52,7 +58,7 @@ struct SubscriptionSourceConnectionTests : public NoLoggerFixture {
StrictMockSubscriptionManagerSharedPtr subscriptionManager_;
StrictMock<MockFunction<void()>> onConnectHook_;
StrictMock<MockFunction<void()>> onDisconnectHook_;
StrictMock<MockFunction<void(bool)>> onDisconnectHook_;
StrictMock<MockFunction<void()>> onLedgerClosedHook_;
SubscriptionSource subscriptionSource_{
@@ -64,8 +70,8 @@ struct SubscriptionSourceConnectionTests : public NoLoggerFixture {
onConnectHook_.AsStdFunction(),
onDisconnectHook_.AsStdFunction(),
onLedgerClosedHook_.AsStdFunction(),
std::chrono::milliseconds(1),
std::chrono::milliseconds(1)
std::chrono::milliseconds(5),
std::chrono::milliseconds(5)
};
[[maybe_unused]] TestWsConnection
@@ -90,15 +96,17 @@ struct SubscriptionSourceConnectionTests : public NoLoggerFixture {
}
};
struct SubscriptionSourceConnectionTests : util::prometheus::WithPrometheus, SubscriptionSourceConnectionTestsBase {};
TEST_F(SubscriptionSourceConnectionTests, ConnectionFailed)
{
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
ioContext_.run();
}
TEST_F(SubscriptionSourceConnectionTests, ConnectionFailed_Retry_ConnectionFailed)
{
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
ioContext_.run();
}
@@ -110,7 +118,19 @@ TEST_F(SubscriptionSourceConnectionTests, ReadError)
});
EXPECT_CALL(onConnectHook_, Call());
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
ioContext_.run();
}
TEST_F(SubscriptionSourceConnectionTests, ReadTimeout)
{
boost::asio::spawn(ioContext_, [this](boost::asio::yield_context yield) {
auto connection = serverConnection(yield);
std::this_thread::sleep_for(std::chrono::milliseconds{10});
});
EXPECT_CALL(onConnectHook_, Call());
EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
ioContext_.run();
}
@@ -124,7 +144,7 @@ TEST_F(SubscriptionSourceConnectionTests, ReadError_Reconnect)
});
EXPECT_CALL(onConnectHook_, Call()).Times(2);
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
ioContext_.run();
}
@@ -137,14 +157,14 @@ TEST_F(SubscriptionSourceConnectionTests, IsConnected)
});
EXPECT_CALL(onConnectHook_, Call()).WillOnce([this]() { EXPECT_TRUE(subscriptionSource_.isConnected()); });
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() {
EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() {
EXPECT_FALSE(subscriptionSource_.isConnected());
subscriptionSource_.stop();
});
ioContext_.run();
}
struct SubscriptionSourceReadTests : public SubscriptionSourceConnectionTests {
struct SubscriptionSourceReadTestsBase : public SubscriptionSourceConnectionTestsBase {
[[maybe_unused]] TestWsConnection
connectAndSendMessage(std::string const message, boost::asio::yield_context yield)
{
@@ -155,6 +175,8 @@ struct SubscriptionSourceReadTests : public SubscriptionSourceConnectionTests {
}
};
struct SubscriptionSourceReadTests : util::prometheus::WithPrometheus, SubscriptionSourceReadTestsBase {};
TEST_F(SubscriptionSourceReadTests, GotWrongMessage_Reconnect)
{
boost::asio::spawn(ioContext_, [this](boost::asio::yield_context yield) {
@@ -165,7 +187,7 @@ TEST_F(SubscriptionSourceReadTests, GotWrongMessage_Reconnect)
});
EXPECT_CALL(onConnectHook_, Call()).Times(2);
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
ioContext_.run();
}
@@ -177,7 +199,7 @@ TEST_F(SubscriptionSourceReadTests, GotResult)
});
EXPECT_CALL(onConnectHook_, Call());
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
ioContext_.run();
}
@@ -189,7 +211,7 @@ TEST_F(SubscriptionSourceReadTests, GotResultWithLedgerIndex)
});
EXPECT_CALL(onConnectHook_, Call());
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(*networkValidatedLedgers_, push(123));
ioContext_.run();
}
@@ -204,7 +226,7 @@ TEST_F(SubscriptionSourceReadTests, GotResultWithLedgerIndexAsString_Reconnect)
});
EXPECT_CALL(onConnectHook_, Call()).Times(2);
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
ioContext_.run();
}
@@ -218,7 +240,7 @@ TEST_F(SubscriptionSourceReadTests, GotResultWithValidatedLedgersAsNumber_Reconn
});
EXPECT_CALL(onConnectHook_, Call()).Times(2);
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
ioContext_.run();
}
@@ -240,7 +262,7 @@ TEST_F(SubscriptionSourceReadTests, GotResultWithValidatedLedgers)
});
EXPECT_CALL(onConnectHook_, Call());
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
ioContext_.run();
EXPECT_TRUE(subscriptionSource_.hasLedger(123));
@@ -266,7 +288,7 @@ TEST_F(SubscriptionSourceReadTests, GotResultWithValidatedLedgersWrongValue_Reco
});
EXPECT_CALL(onConnectHook_, Call()).Times(2);
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
ioContext_.run();
}
@@ -284,7 +306,7 @@ TEST_F(SubscriptionSourceReadTests, GotResultWithLedgerIndexAndValidatedLedgers)
});
EXPECT_CALL(onConnectHook_, Call());
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(*networkValidatedLedgers_, push(123));
ioContext_.run();
@@ -304,7 +326,7 @@ TEST_F(SubscriptionSourceReadTests, GotLedgerClosed)
});
EXPECT_CALL(onConnectHook_, Call());
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
ioContext_.run();
}
@@ -319,7 +341,7 @@ TEST_F(SubscriptionSourceReadTests, GotLedgerClosedForwardingIsSet)
EXPECT_CALL(onConnectHook_, Call());
EXPECT_CALL(onLedgerClosedHook_, Call());
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() {
EXPECT_CALL(onDisconnectHook_, Call(true)).WillOnce([this]() {
EXPECT_FALSE(subscriptionSource_.isForwarding());
subscriptionSource_.stop();
});
@@ -334,7 +356,7 @@ TEST_F(SubscriptionSourceReadTests, GotLedgerClosedWithLedgerIndex)
});
EXPECT_CALL(onConnectHook_, Call());
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(*networkValidatedLedgers_, push(123));
ioContext_.run();
}
@@ -349,7 +371,7 @@ TEST_F(SubscriptionSourceReadTests, GotLedgerClosedWithLedgerIndexAsString_Recon
});
EXPECT_CALL(onConnectHook_, Call()).Times(2);
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
ioContext_.run();
}
@@ -363,7 +385,7 @@ TEST_F(SubscriptionSourceReadTests, GorLedgerClosedWithValidatedLedgersAsNumber_
});
EXPECT_CALL(onConnectHook_, Call()).Times(2);
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
ioContext_.run();
}
@@ -380,7 +402,7 @@ TEST_F(SubscriptionSourceReadTests, GotLedgerClosedWithValidatedLedgers)
});
EXPECT_CALL(onConnectHook_, Call());
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
ioContext_.run();
EXPECT_FALSE(subscriptionSource_.hasLedger(0));
@@ -404,7 +426,7 @@ TEST_F(SubscriptionSourceReadTests, GotLedgerClosedWithLedgerIndexAndValidatedLe
});
EXPECT_CALL(onConnectHook_, Call());
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(*networkValidatedLedgers_, push(123));
ioContext_.run();
@@ -423,7 +445,7 @@ TEST_F(SubscriptionSourceReadTests, GotTransactionIsForwardingFalse)
});
EXPECT_CALL(onConnectHook_, Call());
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
ioContext_.run();
}
@@ -438,7 +460,7 @@ TEST_F(SubscriptionSourceReadTests, GotTransactionIsForwardingTrue)
});
EXPECT_CALL(onConnectHook_, Call());
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(onDisconnectHook_, Call(true)).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(*subscriptionManager_, forwardProposedTransaction(message));
ioContext_.run();
}
@@ -454,7 +476,7 @@ TEST_F(SubscriptionSourceReadTests, GotTransactionWithMetaIsForwardingFalse)
});
EXPECT_CALL(onConnectHook_, Call());
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(onDisconnectHook_, Call(true)).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(*subscriptionManager_, forwardProposedTransaction(message)).Times(0);
ioContext_.run();
}
@@ -467,7 +489,7 @@ TEST_F(SubscriptionSourceReadTests, GotValidationReceivedIsForwardingFalse)
});
EXPECT_CALL(onConnectHook_, Call());
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
ioContext_.run();
}
@@ -482,7 +504,7 @@ TEST_F(SubscriptionSourceReadTests, GotValidationReceivedIsForwardingTrue)
});
EXPECT_CALL(onConnectHook_, Call());
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(onDisconnectHook_, Call(true)).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(*subscriptionManager_, forwardValidation(message));
ioContext_.run();
}
@@ -495,7 +517,7 @@ TEST_F(SubscriptionSourceReadTests, GotManiefstReceivedIsForwardingFalse)
});
EXPECT_CALL(onConnectHook_, Call());
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
ioContext_.run();
}
@@ -510,7 +532,7 @@ TEST_F(SubscriptionSourceReadTests, GotManifestReceivedIsForwardingTrue)
});
EXPECT_CALL(onConnectHook_, Call());
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(onDisconnectHook_, Call(true)).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(*subscriptionManager_, forwardManifest(message));
ioContext_.run();
}
@@ -523,7 +545,7 @@ TEST_F(SubscriptionSourceReadTests, LastMessageTime)
});
EXPECT_CALL(onConnectHook_, Call());
EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
ioContext_.run();
auto const actualLastTimeMessage = subscriptionSource_.lastMessageTime();
@@ -531,3 +553,27 @@ TEST_F(SubscriptionSourceReadTests, LastMessageTime)
auto const diff = std::chrono::duration_cast<std::chrono::milliseconds>(now - actualLastTimeMessage);
EXPECT_LT(diff, std::chrono::milliseconds(100));
}
struct SubscriptionSourcePrometheusCounterTests : util::prometheus::WithMockPrometheus,
SubscriptionSourceReadTestsBase {};
TEST_F(SubscriptionSourcePrometheusCounterTests, LastMessageTime)
{
auto& lastMessageTimeMock = makeMock<util::prometheus::GaugeInt>(
"subscription_source_last_message_time", fmt::format("{{source=\"127.0.0.1:{}\"}}", wsServer_.port())
);
boost::asio::spawn(ioContext_, [this](boost::asio::yield_context yield) {
auto connection = connectAndSendMessage("some_message", yield);
connection.close(yield);
});
EXPECT_CALL(onConnectHook_, Call());
EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
EXPECT_CALL(lastMessageTimeMock, set).WillOnce([](int64_t value) {
auto const now =
std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch())
.count();
EXPECT_LE(now - value, 1);
});
ioContext_.run();
}

View File

@@ -21,6 +21,7 @@
#include "feed/impl/LedgerFeed.hpp"
#include "util/TestObject.hpp"
#include <boost/asio/io_context.hpp>
#include <boost/asio/spawn.hpp>
#include <boost/json/parse.hpp>
#include <gmock/gmock.h>
@@ -57,11 +58,13 @@ TEST_F(FeedLedgerTest, SubPub)
"reserve_base":3,
"reserve_inc":2
})";
boost::asio::spawn(ctx, [this](boost::asio::yield_context yield) {
boost::asio::io_context ioContext;
boost::asio::spawn(ioContext, [this](boost::asio::yield_context yield) {
auto res = testFeedPtr->sub(yield, backend, sessionPtr);
// check the response
EXPECT_EQ(res, json::parse(LedgerResponse));
});
ioContext.run();
EXPECT_EQ(testFeedPtr->count(), 1);
constexpr static auto ledgerPub =

View File

@@ -28,6 +28,7 @@
#include "util/async/context/SyncExecutionContext.hpp"
#include "web/interface/ConnectionBase.hpp"
#include <boost/asio/io_context.hpp>
#include <boost/asio/spawn.hpp>
#include <boost/json/object.hpp>
#include <boost/json/parse.hpp>
@@ -57,13 +58,12 @@ class SubscriptionManagerBaseTest : public util::prometheus::WithPrometheus, pub
protected:
std::shared_ptr<SubscriptionManager> subscriptionManagerPtr;
std::shared_ptr<web::ConnectionBase> session;
Execution ctx{2};
MockSession* sessionPtr = nullptr;
void
SetUp() override
{
subscriptionManagerPtr = std::make_shared<SubscriptionManager>(ctx, backend);
subscriptionManagerPtr = std::make_shared<SubscriptionManager>(Execution(2), backend);
session = std::make_shared<MockSession>();
session->apiSubVersion = 1;
sessionPtr = dynamic_cast<MockSession*>(session.get());
@@ -264,11 +264,13 @@ TEST_F(SubscriptionManagerTest, LedgerTest)
"reserve_base":3,
"reserve_inc":2
})";
boost::asio::io_context ctx;
boost::asio::spawn(ctx, [this](boost::asio::yield_context yield) {
auto const res = subscriptionManagerPtr->subLedger(yield, session);
// check the response
EXPECT_EQ(res, json::parse(LedgerResponse));
});
ctx.run();
EXPECT_EQ(subscriptionManagerPtr->report()["ledger"], 1);
// test publish

View File

@@ -79,7 +79,7 @@ TEST(RPCErrorsTest, StatusAsBool)
TEST(RPCErrorsTest, StatusEquals)
{
EXPECT_EQ(Status{RippledError::rpcUNKNOWN}, Status{RippledError::rpcUNKNOWN});
EXPECT_NE(Status{RippledError::rpcUNKNOWN}, Status{RippledError::rpcREPORTING_UNSUPPORTED});
EXPECT_NE(Status{RippledError::rpcUNKNOWN}, Status{RippledError::rpcINTERNAL});
}
TEST(RPCErrorsTest, SuccessToJSON)

View File

@@ -23,6 +23,7 @@
#include "util/MockCounters.hpp"
#include "util/MockHandlerProvider.hpp"
#include "util/MockLoadBalancer.hpp"
#include "util/NameGenerator.hpp"
#include "util/Taggable.hpp"
#include "util/config/Config.hpp"
#include "web/Context.hpp"
@@ -32,10 +33,12 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <variant>
#include <vector>
using namespace rpc;
using namespace testing;
@@ -59,251 +62,159 @@ protected:
};
};
TEST_F(RPCForwardingProxyTest, ShouldForwardReturnsFalseIfClioOnly)
struct ShouldForwardParamTestCaseBundle {
std::string testName;
std::uint32_t apiVersion;
std::string method;
std::string testJson;
bool mockedIsClioOnly;
std::uint32_t called;
bool isAdmin;
bool expected;
};
struct ShouldForwardParameterTest : public RPCForwardingProxyTest,
WithParamInterface<ShouldForwardParamTestCaseBundle> {};
static auto
generateTestValuesForParametersTest()
{
auto const isClioOnly = true;
auto const isAdmin = true;
auto const shouldForward = true;
return std::vector<ShouldForwardParamTestCaseBundle>{
{"ShouldForwardReturnsFalseIfClioOnly", 2u, "test", "{}", isClioOnly, 1, !isAdmin, !shouldForward},
{"ShouldForwardReturnsTrueIfProxied", 2u, "submit", "{}", !isClioOnly, 1, !isAdmin, shouldForward},
{"ShouldForwardReturnsTrueIfCurrentLedgerSpecified",
2u,
"anymethod",
R"({"ledger_index": "current"})",
!isClioOnly,
1,
!isAdmin,
shouldForward},
{"ShouldForwardReturnsTrueIfClosedLedgerSpecified",
2u,
"anymethod",
R"({"ledger_index": "closed"})",
!isClioOnly,
1,
!isAdmin,
shouldForward},
{"ShouldForwardReturnsTrueIfAccountInfoWithQueueSpecified",
2u,
"account_info",
R"({"queue": true})",
!isClioOnly,
1,
!isAdmin,
shouldForward},
{"ShouldForwardReturnsFalseIfAccountInfoQueueIsFalse",
2u,
"account_info",
R"({"queue": false})",
!isClioOnly,
1,
!isAdmin,
!shouldForward},
{"ShouldForwardReturnsTrueIfLedgerWithQueueSpecified",
2u,
"ledger",
R"({"queue": true})",
!isClioOnly,
1,
!isAdmin,
shouldForward},
{"ShouldForwardReturnsFalseIfLedgerQueueIsFalse",
2u,
"ledger",
R"({"queue": false})",
!isClioOnly,
1,
!isAdmin,
!shouldForward},
{"ShouldNotForwardReturnsTrueIfAPIVersionIsV1",
1u,
"api_version_check",
"{}",
!isClioOnly,
1,
!isAdmin,
!shouldForward},
{"ShouldForwardReturnsFalseIfAPIVersionIsV2",
2u,
"api_version_check",
"{}",
!isClioOnly,
1,
!isAdmin,
!shouldForward},
{"ShouldNeverForwardSubscribe", 1u, "subscribe", "{}", !isClioOnly, 0, !isAdmin, !shouldForward},
{"ShouldNeverForwardUnsubscribe", 1u, "unsubscribe", "{}", !isClioOnly, 0, !isAdmin, !shouldForward},
{"ForceForwardTrue", 1u, "any_method", R"({"force_forward": true})", !isClioOnly, 1, isAdmin, shouldForward},
{"ForceForwardFalse", 1u, "any_method", R"({"force_forward": false})", !isClioOnly, 1, isAdmin, !shouldForward},
{"ForceForwardNotAdmin",
1u,
"any_method",
R"({"force_forward": true})",
!isClioOnly,
1,
!isAdmin,
!shouldForward},
{"ForceForwardSubscribe",
1u,
"subscribe",
R"({"force_forward": true})",
!isClioOnly,
0,
isAdmin,
not shouldForward},
{"ForceForwardUnsubscribe",
1u,
"unsubscribe",
R"({"force_forward": true})",
!isClioOnly,
0,
isAdmin,
!shouldForward},
{"ForceForwardClioOnly",
1u,
"clio_only_method",
R"({"force_forward": true})",
isClioOnly,
1,
isAdmin,
!shouldForward},
};
}
INSTANTIATE_TEST_CASE_P(
ShouldForwardTest,
ShouldForwardParameterTest,
ValuesIn(generateTestValuesForParametersTest()),
tests::util::NameGenerator
);
TEST_P(ShouldForwardParameterTest, Test)
{
auto const testBundle = GetParam();
auto const rawHandlerProviderPtr = handlerProvider.get();
auto const apiVersion = 2u;
auto const method = "test";
auto const params = json::parse("{}");
auto const apiVersion = testBundle.apiVersion;
auto const method = testBundle.method;
auto const params = json::parse(testBundle.testJson);
ON_CALL(*rawHandlerProviderPtr, isClioOnly(_)).WillByDefault(Return(true));
EXPECT_CALL(*rawHandlerProviderPtr, isClioOnly(method)).Times(1);
ON_CALL(*rawHandlerProviderPtr, isClioOnly(_)).WillByDefault(Return(testBundle.mockedIsClioOnly));
EXPECT_CALL(*rawHandlerProviderPtr, isClioOnly(method)).Times(testBundle.called);
runSpawn([&](auto yield) {
auto const range = backend->fetchLedgerRange();
auto const ctx =
web::Context(yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, true);
auto const ctx = web::Context(
yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, testBundle.isAdmin
);
auto const res = proxy.shouldForward(ctx);
ASSERT_FALSE(res);
});
}
TEST_F(RPCForwardingProxyTest, ShouldForwardReturnsTrueIfProxied)
{
auto const rawHandlerProviderPtr = handlerProvider.get();
auto const apiVersion = 2u;
auto const method = "submit";
auto const params = json::parse("{}");
ON_CALL(*rawHandlerProviderPtr, isClioOnly(_)).WillByDefault(Return(false));
EXPECT_CALL(*rawHandlerProviderPtr, isClioOnly(method)).Times(1);
runSpawn([&](auto yield) {
auto const range = backend->fetchLedgerRange();
auto const ctx =
web::Context(yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, true);
auto const res = proxy.shouldForward(ctx);
ASSERT_TRUE(res);
});
}
TEST_F(RPCForwardingProxyTest, ShouldForwardReturnsTrueIfCurrentLedgerSpecified)
{
auto const rawHandlerProviderPtr = handlerProvider.get();
auto const apiVersion = 2u;
auto const method = "anymethod";
auto const params = json::parse(R"({"ledger_index": "current"})");
ON_CALL(*rawHandlerProviderPtr, isClioOnly(_)).WillByDefault(Return(false));
EXPECT_CALL(*rawHandlerProviderPtr, isClioOnly(method)).Times(1);
runSpawn([&](auto yield) {
auto const range = backend->fetchLedgerRange();
auto const ctx =
web::Context(yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, true);
auto const res = proxy.shouldForward(ctx);
ASSERT_TRUE(res);
});
}
TEST_F(RPCForwardingProxyTest, ShouldForwardReturnsTrueIfClosedLedgerSpecified)
{
auto const rawHandlerProviderPtr = handlerProvider.get();
auto const apiVersion = 2u;
auto const method = "anymethod";
auto const params = json::parse(R"({"ledger_index": "closed"})");
ON_CALL(*rawHandlerProviderPtr, isClioOnly(_)).WillByDefault(Return(false));
EXPECT_CALL(*rawHandlerProviderPtr, isClioOnly(method)).Times(1);
runSpawn([&](auto yield) {
auto const range = backend->fetchLedgerRange();
auto const ctx =
web::Context(yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, true);
auto const res = proxy.shouldForward(ctx);
ASSERT_TRUE(res);
});
}
TEST_F(RPCForwardingProxyTest, ShouldForwardReturnsTrueIfAccountInfoWithQueueSpecified)
{
auto const rawHandlerProviderPtr = handlerProvider.get();
auto const apiVersion = 2u;
auto const method = "account_info";
auto const params = json::parse(R"({"queue": true})");
ON_CALL(*rawHandlerProviderPtr, isClioOnly(_)).WillByDefault(Return(false));
EXPECT_CALL(*rawHandlerProviderPtr, isClioOnly(method)).Times(1);
runSpawn([&](auto yield) {
auto const range = backend->fetchLedgerRange();
auto const ctx =
web::Context(yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, true);
auto const res = proxy.shouldForward(ctx);
ASSERT_TRUE(res);
});
}
TEST_F(RPCForwardingProxyTest, ShouldForwardReturnsFalseIfAccountInfoQueueIsFalse)
{
auto const rawHandlerProviderPtr = handlerProvider.get();
auto const apiVersion = 2u;
auto const method = "account_info";
auto const params = json::parse(R"({"queue": false})");
ON_CALL(*rawHandlerProviderPtr, isClioOnly(_)).WillByDefault(Return(false));
EXPECT_CALL(*rawHandlerProviderPtr, isClioOnly(method)).Times(1);
runSpawn([&](auto yield) {
auto const range = backend->fetchLedgerRange();
auto const ctx =
web::Context(yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, true);
auto const res = proxy.shouldForward(ctx);
ASSERT_FALSE(res);
});
}
TEST_F(RPCForwardingProxyTest, ShouldForwardReturnsTrueIfLedgerWithQueueSpecified)
{
auto const rawHandlerProviderPtr = handlerProvider.get();
auto const apiVersion = 2u;
auto const method = "ledger";
auto const params = json::parse(R"({"queue": true})");
ON_CALL(*rawHandlerProviderPtr, isClioOnly(_)).WillByDefault(Return(false));
EXPECT_CALL(*rawHandlerProviderPtr, isClioOnly(method)).Times(1);
runSpawn([&](auto yield) {
auto const range = backend->fetchLedgerRange();
auto const ctx =
web::Context(yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, true);
auto const res = proxy.shouldForward(ctx);
ASSERT_TRUE(res);
});
}
TEST_F(RPCForwardingProxyTest, ShouldForwardReturnsFalseIfLedgerQueueIsFalse)
{
auto const rawHandlerProviderPtr = handlerProvider.get();
auto const apiVersion = 2u;
auto const method = "ledger";
auto const params = json::parse(R"({"queue": false})");
ON_CALL(*rawHandlerProviderPtr, isClioOnly(_)).WillByDefault(Return(false));
EXPECT_CALL(*rawHandlerProviderPtr, isClioOnly(method)).Times(1);
runSpawn([&](auto yield) {
auto const range = backend->fetchLedgerRange();
auto const ctx =
web::Context(yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, true);
auto const res = proxy.shouldForward(ctx);
ASSERT_FALSE(res);
});
}
TEST_F(RPCForwardingProxyTest, ShouldNotForwardReturnsTrueIfAPIVersionIsV1)
{
auto const apiVersion = 1u;
auto const method = "api_version_check";
auto const params = json::parse("{}");
auto const rawHandlerProviderPtr = handlerProvider.get();
ON_CALL(*rawHandlerProviderPtr, isClioOnly(_)).WillByDefault(Return(false));
EXPECT_CALL(*rawHandlerProviderPtr, isClioOnly(method)).Times(1);
runSpawn([&](auto yield) {
auto const range = backend->fetchLedgerRange();
auto const ctx =
web::Context(yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, true);
auto const res = proxy.shouldForward(ctx);
ASSERT_FALSE(res);
});
}
TEST_F(RPCForwardingProxyTest, ShouldForwardReturnsFalseIfAPIVersionIsV2)
{
auto const rawHandlerProviderPtr = handlerProvider.get();
auto const apiVersion = 2u;
auto const method = "api_version_check";
auto const params = json::parse("{}");
ON_CALL(*rawHandlerProviderPtr, isClioOnly(_)).WillByDefault(Return(false));
EXPECT_CALL(*rawHandlerProviderPtr, isClioOnly(method)).Times(1);
runSpawn([&](auto yield) {
auto const range = backend->fetchLedgerRange();
auto const ctx =
web::Context(yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, true);
auto const res = proxy.shouldForward(ctx);
ASSERT_FALSE(res);
});
}
TEST_F(RPCForwardingProxyTest, ShouldNeverForwardFeatureWithVetoedFlag)
{
auto const apiVersion = 1u;
auto const method = "feature";
auto const params = json::parse(R"({"vetoed": true, "feature": "foo"})");
runSpawn([&](auto yield) {
auto const range = backend->fetchLedgerRange();
auto const ctx =
web::Context(yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, true);
auto const res = proxy.shouldForward(ctx);
ASSERT_FALSE(res);
});
}
TEST_F(RPCForwardingProxyTest, ShouldNeverForwardSubscribe)
{
auto const apiVersion = 1u;
auto const method = "subscribe";
auto const params = json::parse("{}");
runSpawn([&](auto yield) {
auto const range = backend->fetchLedgerRange();
auto const ctx =
web::Context(yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, true);
auto const res = proxy.shouldForward(ctx);
ASSERT_FALSE(res);
});
}
TEST_F(RPCForwardingProxyTest, ShouldNeverForwardUnsubscribe)
{
auto const apiVersion = 1u;
auto const method = "unsubscribe";
auto const params = json::parse("{}");
runSpawn([&](auto yield) {
auto const range = backend->fetchLedgerRange();
auto const ctx =
web::Context(yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, true);
auto const res = proxy.shouldForward(ctx);
ASSERT_FALSE(res);
ASSERT_EQ(res, testBundle.expected);
});
}

View File

@@ -25,6 +25,7 @@
#include "util/AsioContextTestFixture.hpp"
#include "util/MockBackendTestFixture.hpp"
#include "util/MockPrometheus.hpp"
#include "util/NameGenerator.hpp"
#include "util/TestObject.hpp"
#include <boost/asio/impl/spawn.hpp>
@@ -539,3 +540,67 @@ TEST_F(RPCHelpersTest, ParseIssue)
std::runtime_error
);
}
struct IsAdminCmdParamTestCaseBundle {
std::string testName;
std::string method;
std::string testJson;
bool expected;
};
struct IsAdminCmdParameterTest : public TestWithParam<IsAdminCmdParamTestCaseBundle> {};
static auto
generateTestValuesForParametersTest()
{
return std::vector<IsAdminCmdParamTestCaseBundle>{
{"ledgerEntry", "ledger_entry", R"({"type": false})", false},
{"featureVetoedTrue", "feature", R"({"vetoed": true, "feature": "foo"})", true},
{"featureVetoedFalse", "feature", R"({"vetoed": false, "feature": "foo"})", true},
{"featureVetoedIsStr", "feature", R"({"vetoed": "String"})", true},
{"ledger", "ledger", R"({})", false},
{"ledgerWithType", "ledger", R"({"type": "fee"})", false},
{"ledgerFullTrue", "ledger", R"({"full": true})", true},
{"ledgerFullFalse", "ledger", R"({"full": false})", false},
{"ledgerFullIsStr", "ledger", R"({"full": "String"})", true},
{"ledgerFullIsEmptyStr", "ledger", R"({"full": ""})", false},
{"ledgerFullIsNumber1", "ledger", R"({"full": 1})", true},
{"ledgerFullIsNumber0", "ledger", R"({"full": 0})", false},
{"ledgerFullIsNull", "ledger", R"({"full": null})", false},
{"ledgerFullIsFloat0", "ledger", R"({"full": 0.0})", false},
{"ledgerFullIsFloat1", "ledger", R"({"full": 0.1})", true},
{"ledgerFullIsArray", "ledger", R"({"full": [1]})", true},
{"ledgerFullIsEmptyArray", "ledger", R"({"full": []})", false},
{"ledgerFullIsObject", "ledger", R"({"full": {"key": 1}})", true},
{"ledgerFullIsEmptyObject", "ledger", R"({"full": {}})", false},
{"ledgerAccountsTrue", "ledger", R"({"accounts": true})", true},
{"ledgerAccountsFalse", "ledger", R"({"accounts": false})", false},
{"ledgerAccountsIsStr", "ledger", R"({"accounts": "String"})", true},
{"ledgerAccountsIsEmptyStr", "ledger", R"({"accounts": ""})", false},
{"ledgerAccountsIsNumber1", "ledger", R"({"accounts": 1})", true},
{"ledgerAccountsIsNumber0", "ledger", R"({"accounts": 0})", false},
{"ledgerAccountsIsNull", "ledger", R"({"accounts": null})", false},
{"ledgerAccountsIsFloat0", "ledger", R"({"accounts": 0.0})", false},
{"ledgerAccountsIsFloat1", "ledger", R"({"accounts": 0.1})", true},
{"ledgerAccountsIsArray", "ledger", R"({"accounts": [1]})", true},
{"ledgerAccountsIsEmptyArray", "ledger", R"({"accounts": []})", false},
{"ledgerAccountsIsObject", "ledger", R"({"accounts": {"key": 1}})", true},
{"ledgerAccountsIsEmptyObject", "ledger", R"({"accounts": {}})", false},
};
}
INSTANTIATE_TEST_CASE_P(
IsAdminCmdTest,
IsAdminCmdParameterTest,
ValuesIn(generateTestValuesForParametersTest()),
tests::util::NameGenerator
);
TEST_P(IsAdminCmdParameterTest, Test)
{
auto const testBundle = GetParam();
EXPECT_EQ(isAdminCmd(testBundle.method, boost::json::parse(testBundle.testJson).as_object()), testBundle.expected);
}

View File

@@ -49,6 +49,7 @@ constexpr static auto TAXON = 0;
constexpr static auto FLAG = 8;
constexpr static auto TXNID = "E6DBAFC99223B42257915A63DFC6B0C032D4070F9A574B255AD97466726FC321";
constexpr static auto PAGE = "E6DBAFC99223B42257915A63DFC6B0C032D4070F9A574B255AD97466726FC322";
constexpr static auto INVALIDPAGE = "E6DBAFC99223B42257915A63DFC6B0C032D4070F9A574B255AD97466726FCAAA";
constexpr static auto MAXSEQ = 30;
constexpr static auto MINSEQ = 10;
@@ -402,6 +403,98 @@ TEST_F(RPCAccountNFTsHandlerTest, Marker)
});
}
TEST_F(RPCAccountNFTsHandlerTest, InvalidMarker)
{
backend->setRange(MINSEQ, MAXSEQ);
auto const ledgerHeader = CreateLedgerHeader(LEDGERHASH, MAXSEQ);
EXPECT_CALL(*backend, fetchLedgerBySequence).Times(1);
ON_CALL(*backend, fetchLedgerBySequence).WillByDefault(Return(ledgerHeader));
auto const accountObject = CreateAccountRootObject(ACCOUNT, 0, 1, 10, 2, TXNID, 3);
auto const accountID = GetAccountIDWithString(ACCOUNT);
ON_CALL(*backend, doFetchLedgerObject(ripple::keylet::account(accountID).key, 30, _))
.WillByDefault(Return(accountObject.getSerializer().peekData()));
auto static const input = json::parse(fmt::format(
R"({{
"account":"{}",
"marker":"{}"
}})",
ACCOUNT,
INVALIDPAGE
));
auto const handler = AnyHandler{AccountNFTsHandler{backend}};
runSpawn([&](auto yield) {
auto const output = handler.process(input, Context{yield});
ASSERT_FALSE(output);
auto const err = rpc::makeError(output.result.error());
EXPECT_EQ(err.at("error").as_string(), "invalidParams");
EXPECT_EQ(err.at("error_message").as_string(), "Marker field does not match any valid Page ID");
});
}
TEST_F(RPCAccountNFTsHandlerTest, AccountWithNoNFT)
{
backend->setRange(MINSEQ, MAXSEQ);
auto const ledgerHeader = CreateLedgerHeader(LEDGERHASH, MAXSEQ);
EXPECT_CALL(*backend, fetchLedgerBySequence).Times(1);
ON_CALL(*backend, fetchLedgerBySequence).WillByDefault(Return(ledgerHeader));
auto const accountObject = CreateAccountRootObject(ACCOUNT, 0, 1, 10, 2, TXNID, 3);
auto const accountID = GetAccountIDWithString(ACCOUNT);
ON_CALL(*backend, doFetchLedgerObject(ripple::keylet::account(accountID).key, 30, _))
.WillByDefault(Return(accountObject.getSerializer().peekData()));
auto static const input = json::parse(fmt::format(
R"({{
"account":"{}"
}})",
ACCOUNT
));
auto const handler = AnyHandler{AccountNFTsHandler{backend}};
runSpawn([&](auto yield) {
auto const output = handler.process(input, Context{yield});
ASSERT_TRUE(output);
EXPECT_EQ(output.result->as_object().at("account_nfts").as_array().size(), 0);
});
}
TEST_F(RPCAccountNFTsHandlerTest, invalidPage)
{
backend->setRange(MINSEQ, MAXSEQ);
auto const ledgerHeader = CreateLedgerHeader(LEDGERHASH, MAXSEQ);
EXPECT_CALL(*backend, fetchLedgerBySequence).Times(1);
ON_CALL(*backend, fetchLedgerBySequence).WillByDefault(Return(ledgerHeader));
auto const accountObject = CreateAccountRootObject(ACCOUNT, 0, 1, 10, 2, TXNID, 3);
auto const accountID = GetAccountIDWithString(ACCOUNT);
ON_CALL(*backend, doFetchLedgerObject(ripple::keylet::account(accountID).key, 30, _))
.WillByDefault(Return(accountObject.getSerializer().peekData()));
auto const pageObject =
CreateNFTTokenPage(std::vector{std::make_pair<std::string, std::string>(TOKENID, "www.ok.com")}, std::nullopt);
ON_CALL(*backend, doFetchLedgerObject(ripple::uint256{PAGE}, 30, _))
.WillByDefault(Return(accountObject.getSerializer().peekData()));
EXPECT_CALL(*backend, doFetchLedgerObject).Times(2);
auto static const input = json::parse(fmt::format(
R"({{
"account":"{}",
"marker":"{}"
}})",
ACCOUNT,
PAGE
));
auto const handler = AnyHandler{AccountNFTsHandler{backend}};
runSpawn([&](auto yield) {
auto const output = handler.process(input, Context{yield});
ASSERT_FALSE(output);
auto const err = rpc::makeError(output.result.error());
EXPECT_EQ(err.at("error").as_string(), "invalidParams");
EXPECT_EQ(err.at("error_message").as_string(), "Marker matches Page ID from another Account");
});
}
TEST_F(RPCAccountNFTsHandlerTest, LimitLessThanMin)
{
static auto const expectedOutput = fmt::format(

View File

@@ -136,22 +136,13 @@ generateTestValuesForParametersTest()
SubscribeParamTestCaseBundle{"StreamNotString", R"({"streams": [1]})", "invalidParams", "streamNotString"},
SubscribeParamTestCaseBundle{"StreamNotValid", R"({"streams": ["1"]})", "malformedStream", "Stream malformed."},
SubscribeParamTestCaseBundle{
"StreamPeerStatusNotSupport",
R"({"streams": ["peer_status"]})",
"reportingUnsupported",
"Requested operation not supported by reporting mode server"
"StreamPeerStatusNotSupport", R"({"streams": ["peer_status"]})", "notSupported", "Operation not supported."
},
SubscribeParamTestCaseBundle{
"StreamConsensusNotSupport",
R"({"streams": ["consensus"]})",
"reportingUnsupported",
"Requested operation not supported by reporting mode server"
"StreamConsensusNotSupport", R"({"streams": ["consensus"]})", "notSupported", "Operation not supported."
},
SubscribeParamTestCaseBundle{
"StreamServerNotSupport",
R"({"streams": ["server"]})",
"reportingUnsupported",
"Requested operation not supported by reporting mode server"
"StreamServerNotSupport", R"({"streams": ["server"]})", "notSupported", "Operation not supported."
},
SubscribeParamTestCaseBundle{"BooksNotArray", R"({"books": "1"})", "invalidParams", "booksNotArray"},
SubscribeParamTestCaseBundle{

View File

@@ -486,22 +486,13 @@ generateTestValuesForParametersTest()
"bothNotBool"
},
UnsubscribeParamTestCaseBundle{
"StreamPeerStatusNotSupport",
R"({"streams": ["peer_status"]})",
"reportingUnsupported",
"Requested operation not supported by reporting mode server"
"StreamPeerStatusNotSupport", R"({"streams": ["peer_status"]})", "notSupported", "Operation not supported."
},
UnsubscribeParamTestCaseBundle{
"StreamConsensusNotSupport",
R"({"streams": ["consensus"]})",
"reportingUnsupported",
"Requested operation not supported by reporting mode server"
"StreamConsensusNotSupport", R"({"streams": ["consensus"]})", "notSupported", "Operation not supported."
},
UnsubscribeParamTestCaseBundle{
"StreamServerNotSupport",
R"({"streams": ["server"]})",
"reportingUnsupported",
"Requested operation not supported by reporting mode server"
"StreamServerNotSupport", R"({"streams": ["server"]})", "notSupported", "Operation not supported."
},
};
}

View File

@@ -18,33 +18,69 @@
//==============================================================================
#include "util/newconfig/Array.hpp"
#include "util/newconfig/ConfigConstraints.hpp"
#include "util/newconfig/ConfigValue.hpp"
#include "util/newconfig/Types.hpp"
#include "util/newconfig/ValueView.hpp"
#include <gtest/gtest.h>
#include <algorithm>
#include <cstdint>
#include <vector>
using namespace util::config;
TEST(ArrayTest, testConfigArray)
TEST(ArrayTest, addSingleValue)
{
auto arr = Array{
ConfigValue{ConfigType::Boolean}.defaultValue(false),
ConfigValue{ConfigType::Integer}.defaultValue(1234),
ConfigValue{ConfigType::Double}.defaultValue(22.22),
};
auto cv = arr.at(0);
ValueView const vv{cv};
EXPECT_EQ(vv.asBool(), false);
auto arr = Array{ConfigValue{ConfigType::Double}};
arr.addValue(111.11);
EXPECT_EQ(arr.size(), 1);
}
auto cv2 = arr.at(1);
TEST(ArrayTest, addAndCheckMultipleValues)
{
auto arr = Array{ConfigValue{ConfigType::Double}};
arr.addValue(111.11);
arr.addValue(222.22);
arr.addValue(333.33);
EXPECT_EQ(arr.size(), 3);
auto const cv = arr.at(0);
ValueView const vv{cv};
EXPECT_EQ(vv.asDouble(), 111.11);
auto const cv2 = arr.at(1);
ValueView const vv2{cv2};
EXPECT_EQ(vv2.asIntType<int>(), 1234);
EXPECT_EQ(vv2.asDouble(), 222.22);
EXPECT_EQ(arr.size(), 3);
arr.emplaceBack(ConfigValue{ConfigType::String}.defaultValue("false"));
arr.addValue(444.44);
EXPECT_EQ(arr.size(), 4);
auto cv4 = arr.at(3);
auto const cv4 = arr.at(3);
ValueView const vv4{cv4};
EXPECT_EQ(vv4.asString(), "false");
EXPECT_EQ(vv4.asDouble(), 444.44);
}
TEST(ArrayTest, testArrayPattern)
{
auto const arr = Array{ConfigValue{ConfigType::String}};
auto const arrPattern = arr.getArrayPattern();
EXPECT_EQ(arrPattern.type(), ConfigType::String);
}
TEST(ArrayTest, iterateValueArray)
{
auto arr = Array{ConfigValue{ConfigType::Integer}.withConstraint(validateUint16)};
std::vector<int64_t> const expected{543, 123, 909};
for (auto const num : expected)
arr.addValue(num);
std::vector<int64_t> actual;
for (auto it = arr.begin(); it != arr.end(); ++it)
actual.emplace_back(std::get<int64_t>(it->getValue()));
EXPECT_TRUE(std::ranges::equal(expected, actual));
}

View File

@@ -19,11 +19,13 @@
#include "util/newconfig/ArrayView.hpp"
#include "util/newconfig/ConfigDefinition.hpp"
#include "util/newconfig/ConfigValue.hpp"
#include "util/newconfig/ConfigFileJson.hpp"
#include "util/newconfig/FakeConfigData.hpp"
#include "util/newconfig/ObjectView.hpp"
#include "util/newconfig/Types.hpp"
#include "util/newconfig/ValueView.hpp"
#include <boost/json/parse.hpp>
#include <gtest/gtest.h>
#include <cstddef>
@@ -31,33 +33,65 @@
using namespace util::config;
struct ArrayViewTest : testing::Test {
ClioConfigDefinition const configData = generateConfig();
ArrayViewTest()
{
ConfigFileJson const jsonFileObj{boost::json::parse(JSONData).as_object()};
auto const errors = configData.parse(jsonFileObj);
EXPECT_TRUE(!errors.has_value());
}
ClioConfigDefinition configData = generateConfig();
};
TEST_F(ArrayViewTest, ArrayValueTest)
// Array View tests can only be tested after the values are populated from user Config
// into ConfigClioDefinition
TEST_F(ArrayViewTest, ArrayGetValueDouble)
{
ArrayView const arrVals = configData.getArray("array.[].sub");
auto valIt = arrVals.begin<ValueView>();
auto const precision = 1e-9;
EXPECT_NEAR((*valIt++).asDouble(), 111.11, precision);
EXPECT_NEAR((*valIt++).asDouble(), 4321.55, precision);
EXPECT_EQ(valIt, arrVals.end<ValueView>());
ArrayView const arrVals = configData.getArray("array.[].sub");
EXPECT_NEAR(111.11, arrVals.valueAt(0).asDouble(), precision);
auto const firstVal = arrVals.valueAt(0);
EXPECT_EQ(firstVal.type(), ConfigType::Double);
EXPECT_TRUE(firstVal.hasValue());
EXPECT_FALSE(firstVal.isOptional());
EXPECT_NEAR(111.11, firstVal.asDouble(), precision);
EXPECT_NEAR(4321.55, arrVals.valueAt(1).asDouble(), precision);
ArrayView const arrVals2 = configData.getArray("array.[].sub2");
auto val2It = arrVals2.begin<ValueView>();
EXPECT_EQ((*val2It++).asString(), "subCategory");
EXPECT_EQ((*val2It++).asString(), "temporary");
EXPECT_EQ(val2It, arrVals2.end<ValueView>());
ValueView const tempVal = arrVals2.valueAt(0);
EXPECT_EQ(tempVal.type(), ConfigType::String);
EXPECT_EQ("subCategory", tempVal.asString());
}
TEST_F(ArrayViewTest, ArrayWithObjTest)
TEST_F(ArrayViewTest, ArrayGetValueString)
{
ArrayView const arrVals = configData.getArray("array.[].sub2");
ValueView const firstVal = arrVals.valueAt(0);
EXPECT_EQ(firstVal.type(), ConfigType::String);
EXPECT_EQ("subCategory", firstVal.asString());
EXPECT_EQ("london", arrVals.valueAt(2).asString());
}
TEST_F(ArrayViewTest, IterateValuesDouble)
{
auto const precision = 1e-9;
ArrayView const arrVals = configData.getArray("array.[].sub");
auto valIt = arrVals.begin<ValueView>();
EXPECT_NEAR((*valIt++).asDouble(), 111.11, precision);
EXPECT_NEAR((*valIt++).asDouble(), 4321.55, precision);
EXPECT_NEAR((*valIt++).asDouble(), 5555.44, precision);
EXPECT_EQ(valIt, arrVals.end<ValueView>());
}
TEST_F(ArrayViewTest, IterateValuesString)
{
ArrayView const arrVals = configData.getArray("array.[].sub2");
auto val2It = arrVals.begin<ValueView>();
EXPECT_EQ((*val2It++).asString(), "subCategory");
EXPECT_EQ((*val2It++).asString(), "temporary");
EXPECT_EQ((*val2It++).asString(), "london");
EXPECT_EQ(val2It, arrVals.end<ValueView>());
}
TEST_F(ArrayViewTest, ArrayWithObj)
{
ArrayView const arrVals = configData.getArray("array.[]");
ArrayView const arrValAlt = configData.getArray("array");
@@ -73,20 +107,19 @@ TEST_F(ArrayViewTest, IterateArray)
{
auto arr = configData.getArray("dosguard.whitelist");
EXPECT_EQ(2, arr.size());
EXPECT_EQ(arr.valueAt(0).asString(), "125.5.5.2");
EXPECT_EQ(arr.valueAt(1).asString(), "204.2.2.2");
EXPECT_EQ(arr.valueAt(0).asString(), "125.5.5.1");
EXPECT_EQ(arr.valueAt(1).asString(), "204.2.2.1");
auto it = arr.begin<ValueView>();
EXPECT_EQ((*it++).asString(), "125.5.5.2");
EXPECT_EQ((*it++).asString(), "204.2.2.2");
EXPECT_EQ((*it++).asString(), "125.5.5.1");
EXPECT_EQ((*it++).asString(), "204.2.2.1");
EXPECT_EQ((it), arr.end<ValueView>());
}
TEST_F(ArrayViewTest, DifferentArrayIterators)
TEST_F(ArrayViewTest, CompareDifferentArrayIterators)
{
auto const subArray = configData.getArray("array.[].sub");
auto const dosguardArray = configData.getArray("dosguard.whitelist.[]");
ASSERT_EQ(subArray.size(), dosguardArray.size());
auto itArray = subArray.begin<ValueView>();
auto itDosguard = dosguardArray.begin<ValueView>();
@@ -98,7 +131,7 @@ TEST_F(ArrayViewTest, DifferentArrayIterators)
TEST_F(ArrayViewTest, IterateObject)
{
auto arr = configData.getArray("array");
EXPECT_EQ(2, arr.size());
EXPECT_EQ(3, arr.size());
auto it = arr.begin<ObjectView>();
EXPECT_EQ(111.11, (*it).getValue("sub").asDouble());
@@ -107,33 +140,37 @@ TEST_F(ArrayViewTest, IterateObject)
EXPECT_EQ(4321.55, (*it).getValue("sub").asDouble());
EXPECT_EQ("temporary", (*it++).getValue("sub2").asString());
EXPECT_EQ(5555.44, (*it).getValue("sub").asDouble());
EXPECT_EQ("london", (*it++).getValue("sub2").asString());
EXPECT_EQ(it, arr.end<ObjectView>());
}
struct ArrayViewDeathTest : ArrayViewTest {};
TEST_F(ArrayViewDeathTest, IncorrectAccess)
TEST_F(ArrayViewDeathTest, AccessArrayOutOfBounce)
{
ArrayView const arr = configData.getArray("higher");
// dies because higher only has 1 object (trying to access 2nd element)
EXPECT_DEATH({ [[maybe_unused]] auto _ = configData.getArray("higher").objectAt(1); }, ".*");
}
// dies because higher only has 1 object
EXPECT_DEATH({ [[maybe_unused]] auto _ = arr.objectAt(1); }, ".*");
ArrayView const arrVals2 = configData.getArray("array.[].sub2");
ValueView const tempVal = arrVals2.valueAt(0);
// dies because array.[].sub2 only has 2 config values
EXPECT_DEATH([[maybe_unused]] auto _ = arrVals2.valueAt(2), ".*");
TEST_F(ArrayViewDeathTest, AccessIndexOfWrongType)
{
auto const& arrVals2 = configData.getArray("array.[].sub2");
auto const& tempVal = arrVals2.valueAt(0);
// dies as value is not of type int
EXPECT_DEATH({ [[maybe_unused]] auto _ = tempVal.asIntType<int>(); }, ".*");
}
TEST_F(ArrayViewDeathTest, IncorrectIterateAccess)
TEST_F(ArrayViewDeathTest, GetValueWhenItIsObject)
{
ArrayView const arr = configData.getArray("higher");
EXPECT_DEATH({ [[maybe_unused]] auto _ = arr.begin<ValueView>(); }, ".*");
}
TEST_F(ArrayViewDeathTest, GetObjectWhenItIsValue)
{
ArrayView const dosguardWhitelist = configData.getArray("dosguard.whitelist");
EXPECT_DEATH({ [[maybe_unused]] auto _ = dosguardWhitelist.begin<ObjectView>(); }, ".*");
}

View File

@@ -20,17 +20,25 @@
#include "util/newconfig/ArrayView.hpp"
#include "util/newconfig/ConfigDefinition.hpp"
#include "util/newconfig/ConfigDescription.hpp"
#include "util/newconfig/ConfigValue.hpp"
#include "util/newconfig/ConfigFileJson.hpp"
#include "util/newconfig/FakeConfigData.hpp"
#include "util/newconfig/Types.hpp"
#include "util/newconfig/ValueView.hpp"
#include <boost/json/object.hpp>
#include <boost/json/parse.hpp>
#include <boost/json/value.hpp>
#include <gtest/gtest.h>
#include <algorithm>
#include <cstdint>
#include <string>
#include <string_view>
#include <unordered_set>
#include <vector>
using namespace util::config;
// TODO: parsing config file and populating into config will be here once implemented
struct NewConfigTest : testing::Test {
ClioConfigDefinition const configData = generateConfig();
};
@@ -45,12 +53,9 @@ TEST_F(NewConfigTest, fetchValues)
EXPECT_EQ(true, configData.getValue("header.admin").asBool());
EXPECT_EQ("TSM", configData.getValue("header.sub.sub2Value").asString());
EXPECT_EQ(444.22, configData.getValue("ip").asDouble());
auto const v2 = configData.getValueInArray("dosguard.whitelist", 0);
EXPECT_EQ(v2.asString(), "125.5.5.2");
}
TEST_F(NewConfigTest, fetchObject)
TEST_F(NewConfigTest, fetchObjectDirectly)
{
auto const obj = configData.getObject("header");
EXPECT_TRUE(obj.containsKey("sub.sub2Value"));
@@ -58,27 +63,6 @@ TEST_F(NewConfigTest, fetchObject)
auto const obj2 = obj.getObject("sub");
EXPECT_TRUE(obj2.containsKey("sub2Value"));
EXPECT_EQ(obj2.getValue("sub2Value").asString(), "TSM");
auto const objInArr = configData.getObject("array", 0);
auto const obj2InArr = configData.getObject("array", 1);
EXPECT_EQ(objInArr.getValue("sub").asDouble(), 111.11);
EXPECT_EQ(objInArr.getValue("sub2").asString(), "subCategory");
EXPECT_EQ(obj2InArr.getValue("sub").asDouble(), 4321.55);
EXPECT_EQ(obj2InArr.getValue("sub2").asString(), "temporary");
}
TEST_F(NewConfigTest, fetchArray)
{
auto const obj = configData.getObject("dosguard");
EXPECT_TRUE(obj.containsKey("whitelist.[]"));
auto const arr = obj.getArray("whitelist");
EXPECT_EQ(2, arr.size());
auto const sameArr = configData.getArray("dosguard.whitelist");
EXPECT_EQ(2, sameArr.size());
EXPECT_EQ(sameArr.valueAt(0).asString(), arr.valueAt(0).asString());
EXPECT_EQ(sameArr.valueAt(1).asString(), arr.valueAt(1).asString());
}
TEST_F(NewConfigTest, CheckKeys)
@@ -91,9 +75,11 @@ TEST_F(NewConfigTest, CheckKeys)
EXPECT_TRUE(configData.hasItemsWithPrefix("dosguard"));
EXPECT_TRUE(configData.hasItemsWithPrefix("ip"));
EXPECT_EQ(configData.arraySize("array"), 2);
EXPECT_EQ(configData.arraySize("higher"), 1);
EXPECT_EQ(configData.arraySize("dosguard.whitelist"), 2);
// all arrays currently not populated, only has "itemPattern_" that defines
// the type/constraint each configValue will have later on
EXPECT_EQ(configData.arraySize("array"), 0);
EXPECT_EQ(configData.arraySize("higher"), 0);
EXPECT_EQ(configData.arraySize("dosguard.whitelist"), 0);
}
TEST_F(NewConfigTest, CheckAllKeys)
@@ -110,7 +96,10 @@ TEST_F(NewConfigTest, CheckAllKeys)
"higher.[].low.section",
"higher.[].low.admin",
"dosguard.whitelist.[]",
"dosguard.port"
"dosguard.port",
"optional.withDefault",
"optional.withNoDefault",
"requireValue"
};
for (auto i = configData.begin(); i != configData.end(); ++i) {
@@ -121,31 +110,42 @@ TEST_F(NewConfigTest, CheckAllKeys)
struct NewConfigDeathTest : NewConfigTest {};
TEST_F(NewConfigDeathTest, IncorrectGetValues)
TEST_F(NewConfigDeathTest, GetNonExistentKeys)
{
EXPECT_DEATH({ [[maybe_unused]] auto a_ = configData.getValue("head"); }, ".*");
EXPECT_DEATH({ [[maybe_unused]] auto a_ = configData.getValue("head."); }, ".*");
EXPECT_DEATH({ [[maybe_unused]] auto a_ = configData.getValue("asdf"); }, ".*");
}
TEST_F(NewConfigDeathTest, GetValueButIsArray)
{
EXPECT_DEATH({ [[maybe_unused]] auto a_ = configData.getValue("dosguard.whitelist"); }, ".*");
EXPECT_DEATH({ [[maybe_unused]] auto a_ = configData.getValue("dosguard.whitelist.[]"); }, ".*");
}
TEST_F(NewConfigDeathTest, IncorrectGetObject)
TEST_F(NewConfigDeathTest, GetNonExistentObjectKey)
{
ASSERT_FALSE(configData.contains("head"));
EXPECT_DEATH({ [[maybe_unused]] auto a_ = configData.getObject("head"); }, ".*");
EXPECT_DEATH({ [[maybe_unused]] auto a_ = configData.getObject("array"); }, ".*");
EXPECT_DEATH({ [[maybe_unused]] auto a_ = configData.getObject("array", 2); }, ".*");
EXPECT_DEATH({ [[maybe_unused]] auto a_ = configData.getObject("doesNotExist"); }, ".*");
}
TEST_F(NewConfigDeathTest, IncorrectGetArray)
TEST_F(NewConfigDeathTest, GetObjectButIsArray)
{
EXPECT_DEATH({ [[maybe_unused]] auto a_ = configData.getObject("array"); }, ".*");
EXPECT_DEATH({ [[maybe_unused]] auto a_ = configData.getObject("array", 2); }, ".*");
}
TEST_F(NewConfigDeathTest, GetArrayButIsValue)
{
EXPECT_DEATH({ [[maybe_unused]] auto a_ = configData.getArray("header.text1"); }, ".*");
}
TEST_F(NewConfigDeathTest, GetNonExistentArrayKey)
{
EXPECT_DEATH({ [[maybe_unused]] auto a_ = configData.getArray("asdf"); }, ".*");
}
TEST(ConfigDescription, getValues)
TEST(ConfigDescription, GetValues)
{
ClioConfigDescription const definition{};
@@ -154,10 +154,150 @@ TEST(ConfigDescription, getValues)
EXPECT_EQ(definition.get("prometheus.enabled"), "Enable or disable Prometheus metrics.");
}
TEST(ConfigDescriptionAssertDeathTest, nonExistingKeyTest)
TEST(ConfigDescriptionAssertDeathTest, NonExistingKeyTest)
{
ClioConfigDescription const definition{};
EXPECT_DEATH({ [[maybe_unused]] auto a = definition.get("data"); }, ".*");
EXPECT_DEATH({ [[maybe_unused]] auto a = definition.get("etl_source.[]"); }, ".*");
}
/** @brief Testing override the default values with the ones in Json */
struct OverrideConfigVals : testing::Test {
OverrideConfigVals()
{
ConfigFileJson const jsonFileObj{boost::json::parse(JSONData).as_object()};
auto const errors = configData.parse(jsonFileObj);
EXPECT_TRUE(!errors.has_value());
}
ClioConfigDefinition configData = generateConfig();
};
TEST_F(OverrideConfigVals, ValidateValuesStrings)
{
// make sure the values in configData are overriden
EXPECT_TRUE(configData.contains("header.text1"));
EXPECT_EQ(configData.getValue("header.text1").asString(), "value");
EXPECT_FALSE(configData.contains("header.sub"));
EXPECT_TRUE(configData.contains("header.sub.sub2Value"));
EXPECT_EQ(configData.getValue("header.sub.sub2Value").asString(), "TSM");
EXPECT_TRUE(configData.contains("requireValue"));
EXPECT_EQ(configData.getValue("requireValue").asString(), "required");
}
TEST_F(OverrideConfigVals, ValidateValuesDouble)
{
EXPECT_TRUE(configData.contains("optional.withDefault"));
EXPECT_EQ(configData.getValue("optional.withDefault").asDouble(), 0.0);
// make sure the values not overwritten, (default values) are there too
EXPECT_TRUE(configData.contains("ip"));
EXPECT_EQ(configData.getValue("ip").asDouble(), 444.22);
}
TEST_F(OverrideConfigVals, ValidateValuesInteger)
{
EXPECT_TRUE(configData.contains("dosguard.port"));
EXPECT_EQ(configData.getValue("dosguard.port").asIntType<int>(), 44444);
EXPECT_TRUE(configData.contains("header.port"));
EXPECT_EQ(configData.getValue("header.port").asIntType<int64_t>(), 321);
}
TEST_F(OverrideConfigVals, ValidateValuesBool)
{
EXPECT_TRUE(configData.contains("header.admin"));
EXPECT_EQ(configData.getValue("header.admin").asBool(), false);
}
TEST_F(OverrideConfigVals, ValidateIntegerValuesInArrays)
{
// Check array values (sub)
EXPECT_TRUE(configData.contains("array.[].sub"));
auto const arrSub = configData.getArray("array.[].sub");
std::vector<double> expectedArrSubVal{111.11, 4321.55, 5555.44};
std::vector<double> actualArrSubVal{};
for (auto it = arrSub.begin<ValueView>(); it != arrSub.end<ValueView>(); ++it) {
actualArrSubVal.emplace_back((*it).asDouble());
}
EXPECT_TRUE(std::ranges::equal(expectedArrSubVal, actualArrSubVal));
}
TEST_F(OverrideConfigVals, ValidateStringValuesInArrays)
{
// Check array values (sub2)
EXPECT_TRUE(configData.contains("array.[].sub2"));
auto const arrSub2 = configData.getArray("array.[].sub2");
std::vector<std::string> expectedArrSub2Val{"subCategory", "temporary", "london"};
std::vector<std::string> actualArrSub2Val{};
for (auto it = arrSub2.begin<ValueView>(); it != arrSub2.end<ValueView>(); ++it) {
actualArrSub2Val.emplace_back((*it).asString());
}
EXPECT_TRUE(std::ranges::equal(expectedArrSub2Val, actualArrSub2Val));
// Check dosguard values
EXPECT_TRUE(configData.contains("dosguard.whitelist.[]"));
auto const dosguard = configData.getArray("dosguard.whitelist.[]");
EXPECT_EQ("125.5.5.1", dosguard.valueAt(0).asString());
EXPECT_EQ("204.2.2.1", dosguard.valueAt(1).asString());
}
TEST_F(OverrideConfigVals, FetchArray)
{
auto const obj = configData.getObject("dosguard");
EXPECT_TRUE(obj.containsKey("whitelist.[]"));
auto const arr = obj.getArray("whitelist");
EXPECT_EQ(2, arr.size());
auto const sameArr = configData.getArray("dosguard.whitelist");
EXPECT_EQ(2, sameArr.size());
EXPECT_EQ(sameArr.valueAt(0).asString(), arr.valueAt(0).asString());
EXPECT_EQ(sameArr.valueAt(1).asString(), arr.valueAt(1).asString());
}
TEST_F(OverrideConfigVals, FetchObjectByArray)
{
auto const objInArr = configData.getObject("array", 0);
auto const obj2InArr = configData.getObject("array", 1);
auto const obj3InArr = configData.getObject("array", 2);
EXPECT_EQ(objInArr.getValue("sub").asDouble(), 111.11);
EXPECT_EQ(objInArr.getValue("sub2").asString(), "subCategory");
EXPECT_EQ(obj2InArr.getValue("sub").asDouble(), 4321.55);
EXPECT_EQ(obj2InArr.getValue("sub2").asString(), "temporary");
EXPECT_EQ(obj3InArr.getValue("sub").asDouble(), 5555.44);
EXPECT_EQ(obj3InArr.getValue("sub2").asString(), "london");
}
struct IncorrectOverrideValues : testing::Test {
ClioConfigDefinition configData = generateConfig();
};
TEST_F(IncorrectOverrideValues, InvalidJsonErrors)
{
ConfigFileJson const jsonFileObj{boost::json::parse(invalidJSONData).as_object()};
auto const errors = configData.parse(jsonFileObj);
EXPECT_TRUE(errors.has_value());
// Expected error messages
std::unordered_set<std::string_view> const expectedErrors{
"dosguard.whitelist.[] value does not match type string",
"higher.[].low.section key is required in user Config",
"higher.[].low.admin key is required in user Config",
"array.[].sub key is required in user Config",
"header.port value does not match type integer",
"header.admin value does not match type boolean",
"optional.withDefault value does not match type double"
};
std::unordered_set<std::string_view> actualErrors;
for (auto const& error : errors.value()) {
actualErrors.insert(error.error);
}
EXPECT_EQ(expectedErrors, actualErrors);
}

View File

@@ -17,24 +17,207 @@
*/
//==============================================================================
#include "util/newconfig/ConfigConstraints.hpp"
#include "util/newconfig/ConfigValue.hpp"
#include "util/newconfig/Types.hpp"
#include <fmt/core.h>
#include <gtest/gtest.h>
#include <array>
#include <string>
using namespace util::config;
TEST(ConfigValue, testConfigValue)
TEST(ConfigValue, GetSetString)
{
auto cvStr = ConfigValue{ConfigType::String}.defaultValue("12345");
auto const cvStr = ConfigValue{ConfigType::String}.defaultValue("12345");
EXPECT_EQ(cvStr.type(), ConfigType::String);
EXPECT_TRUE(cvStr.hasValue());
EXPECT_FALSE(cvStr.isOptional());
}
auto cvInt = ConfigValue{ConfigType::Integer}.defaultValue(543);
TEST(ConfigValue, GetSetInteger)
{
auto const cvInt = ConfigValue{ConfigType::Integer}.defaultValue(543);
EXPECT_EQ(cvInt.type(), ConfigType::Integer);
EXPECT_TRUE(cvStr.hasValue());
EXPECT_FALSE(cvStr.isOptional());
EXPECT_TRUE(cvInt.hasValue());
EXPECT_FALSE(cvInt.isOptional());
auto cvOpt = ConfigValue{ConfigType::Integer}.optional();
auto const cvOpt = ConfigValue{ConfigType::Integer}.optional();
EXPECT_TRUE(cvOpt.isOptional());
}
// A test for each constraint so it's easy to change in the future
TEST(ConfigValue, PortConstraint)
{
auto const portConstraint{PortConstraint{}};
EXPECT_FALSE(portConstraint.checkConstraint(4444).has_value());
EXPECT_TRUE(portConstraint.checkConstraint(99999).has_value());
}
TEST(ConfigValue, SetValuesOnPortConstraint)
{
auto cvPort = ConfigValue{ConfigType::Integer}.defaultValue(4444).withConstraint(validatePort);
auto const err = cvPort.setValue(99999);
EXPECT_TRUE(err.has_value());
EXPECT_EQ(err->error, "Port does not satisfy the constraint bounds");
EXPECT_TRUE(cvPort.setValue(33.33).has_value());
EXPECT_TRUE(cvPort.setValue(33.33).value().error == "value does not match type integer");
EXPECT_FALSE(cvPort.setValue(1).has_value());
auto cvPort2 = ConfigValue{ConfigType::String}.defaultValue("4444").withConstraint(validatePort);
auto const strPortError = cvPort2.setValue("100000");
EXPECT_TRUE(strPortError.has_value());
EXPECT_EQ(strPortError->error, "Port does not satisfy the constraint bounds");
}
TEST(ConfigValue, OneOfConstraintOneValue)
{
std::array<char const*, 1> const arr = {"tracer"};
auto const databaseConstraint{OneOf{"database.type", arr}};
EXPECT_FALSE(databaseConstraint.checkConstraint("tracer").has_value());
EXPECT_TRUE(databaseConstraint.checkConstraint(345).has_value());
EXPECT_EQ(databaseConstraint.checkConstraint(345)->error, R"(Key "database.type"'s value must be a string)");
EXPECT_TRUE(databaseConstraint.checkConstraint("123.44").has_value());
EXPECT_EQ(
databaseConstraint.checkConstraint("123.44")->error,
R"(You provided value "123.44". Key "database.type"'s value must be one of the following: tracer)"
);
}
TEST(ConfigValue, OneOfConstraint)
{
std::array<char const*, 3> const arr = {"123", "trace", "haha"};
auto const oneOfCons{OneOf{"log_level", arr}};
EXPECT_FALSE(oneOfCons.checkConstraint("trace").has_value());
EXPECT_TRUE(oneOfCons.checkConstraint(345).has_value());
EXPECT_EQ(oneOfCons.checkConstraint(345)->error, R"(Key "log_level"'s value must be a string)");
EXPECT_TRUE(oneOfCons.checkConstraint("PETER_WAS_HERE").has_value());
EXPECT_EQ(
oneOfCons.checkConstraint("PETER_WAS_HERE")->error,
R"(You provided value "PETER_WAS_HERE". Key "log_level"'s value must be one of the following: 123, trace, haha)"
);
}
TEST(ConfigValue, IpConstraint)
{
auto ip = ConfigValue{ConfigType::String}.defaultValue("127.0.0.1").withConstraint(validateIP);
EXPECT_FALSE(ip.setValue("http://127.0.0.1").has_value());
EXPECT_FALSE(ip.setValue("http://127.0.0.1.com").has_value());
auto const err = ip.setValue("123.44");
EXPECT_TRUE(err.has_value());
EXPECT_EQ(err->error, "Ip is not a valid ip address");
EXPECT_FALSE(ip.setValue("126.0.0.2"));
EXPECT_TRUE(ip.setValue("644.3.3.0"));
EXPECT_TRUE(ip.setValue("127.0.0.1.0"));
EXPECT_TRUE(ip.setValue(""));
EXPECT_TRUE(ip.setValue("http://example..com"));
EXPECT_FALSE(ip.setValue("localhost"));
EXPECT_FALSE(ip.setValue("http://example.com:8080/path"));
}
TEST(ConfigValue, positiveNumConstraint)
{
auto const numCons{NumberValueConstraint{0, 5}};
EXPECT_FALSE(numCons.checkConstraint(0));
EXPECT_FALSE(numCons.checkConstraint(5));
EXPECT_TRUE(numCons.checkConstraint(true));
EXPECT_EQ(numCons.checkConstraint(true)->error, fmt::format("Number must be of type integer"));
EXPECT_TRUE(numCons.checkConstraint(8));
EXPECT_EQ(numCons.checkConstraint(8)->error, fmt::format("Number must be between {} and {}", 0, 5));
}
TEST(ConfigValue, SetValuesOnNumberConstraint)
{
auto positiveNum = ConfigValue{ConfigType::Integer}.defaultValue(20u).withConstraint(validateUint16);
auto const err = positiveNum.setValue(-22, "key");
EXPECT_TRUE(err.has_value());
EXPECT_EQ(err->error, fmt::format("key Number must be between {} and {}", 0, 65535));
EXPECT_FALSE(positiveNum.setValue(99, "key"));
}
TEST(ConfigValue, PositiveDoubleConstraint)
{
auto const doubleCons{PositiveDouble{}};
EXPECT_FALSE(doubleCons.checkConstraint(0.2));
EXPECT_FALSE(doubleCons.checkConstraint(5.54));
EXPECT_TRUE(doubleCons.checkConstraint("-5"));
EXPECT_EQ(doubleCons.checkConstraint("-5")->error, "Double number must be of type int or double");
EXPECT_EQ(doubleCons.checkConstraint(-5.6)->error, "Double number must be greater than 0");
EXPECT_FALSE(doubleCons.checkConstraint(12.1));
}
struct ConstraintTestBundle {
std::string name;
Constraint const& cons_;
};
struct ConstraintDeathTest : public testing::Test, public testing::WithParamInterface<ConstraintTestBundle> {};
INSTANTIATE_TEST_SUITE_P(
EachConstraints,
ConstraintDeathTest,
testing::Values(
ConstraintTestBundle{"logTagConstraint", validateLogTag},
ConstraintTestBundle{"portConstraint", validatePort},
ConstraintTestBundle{"ipConstraint", validateIP},
ConstraintTestBundle{"channelConstraint", validateChannelName},
ConstraintTestBundle{"logLevelConstraint", validateLogLevelName},
ConstraintTestBundle{"cannsandraNameCnstraint", validateCassandraName},
ConstraintTestBundle{"loadModeConstraint", validateLoadMode},
ConstraintTestBundle{"ChannelNameConstraint", validateChannelName},
ConstraintTestBundle{"ApiVersionConstraint", validateApiVersion},
ConstraintTestBundle{"Uint16Constraint", validateUint16},
ConstraintTestBundle{"Uint32Constraint", validateUint32},
ConstraintTestBundle{"PositiveDoubleConstraint", validatePositiveDouble}
),
[](testing::TestParamInfo<ConstraintTestBundle> const& info) { return info.param.name; }
);
TEST_P(ConstraintDeathTest, TestEachConstraint)
{
EXPECT_DEATH(
{
[[maybe_unused]] auto const a =
ConfigValue{ConfigType::Boolean}.defaultValue(true).withConstraint(GetParam().cons_);
},
".*"
);
}
TEST(ConfigValueDeathTest, SetInvalidValueTypeStringAndBool)
{
EXPECT_DEATH(
{
[[maybe_unused]] auto a = ConfigValue{ConfigType::String}.defaultValue(33).withConstraint(validateLoadMode);
},
".*"
);
EXPECT_DEATH({ [[maybe_unused]] auto a = ConfigValue{ConfigType::Boolean}.defaultValue(-66); }, ".*");
}
TEST(ConfigValueDeathTest, OutOfBounceIntegerConstraint)
{
EXPECT_DEATH(
{
[[maybe_unused]] auto a =
ConfigValue{ConfigType::Integer}.defaultValue(999999).withConstraint(validateUint16);
},
".*"
);
EXPECT_DEATH(
{
[[maybe_unused]] auto a = ConfigValue{ConfigType::Integer}.defaultValue(-66).withConstraint(validateUint32);
},
".*"
);
}

View File

@@ -0,0 +1,113 @@
//------------------------------------------------------------------------------
/*
This file is part of clio: https://github.com/XRPLF/clio
Copyright (c) 2024, the clio developers.
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#include "util/TmpFile.hpp"
#include "util/newconfig/ConfigFileJson.hpp"
#include "util/newconfig/FakeConfigData.hpp"
#include <boost/json/parse.hpp>
#include <gtest/gtest.h>
#include <algorithm>
#include <cstdint>
#include <string>
#include <variant>
#include <vector>
TEST(CreateConfigFile, filePath)
{
auto const jsonFileObj = ConfigFileJson::make_ConfigFileJson(TmpFile(JSONData).path);
EXPECT_TRUE(jsonFileObj.has_value());
EXPECT_TRUE(jsonFileObj->containsKey("array.[].sub"));
auto const arrSub = jsonFileObj->getArray("array.[].sub");
EXPECT_EQ(arrSub.size(), 3);
}
TEST(CreateConfigFile, incorrectFilePath)
{
auto const jsonFileObj = util::config::ConfigFileJson::make_ConfigFileJson("123/clio");
EXPECT_FALSE(jsonFileObj.has_value());
}
struct ParseJson : testing::Test {
ParseJson() : jsonFileObj{boost::json::parse(JSONData).as_object()}
{
}
ConfigFileJson const jsonFileObj;
};
TEST_F(ParseJson, validateValues)
{
EXPECT_TRUE(jsonFileObj.containsKey("header.text1"));
EXPECT_EQ(std::get<std::string>(jsonFileObj.getValue("header.text1")), "value");
EXPECT_TRUE(jsonFileObj.containsKey("header.sub.sub2Value"));
EXPECT_EQ(std::get<std::string>(jsonFileObj.getValue("header.sub.sub2Value")), "TSM");
EXPECT_TRUE(jsonFileObj.containsKey("dosguard.port"));
EXPECT_EQ(std::get<int64_t>(jsonFileObj.getValue("dosguard.port")), 44444);
EXPECT_FALSE(jsonFileObj.containsKey("idk"));
EXPECT_FALSE(jsonFileObj.containsKey("optional.withNoDefault"));
}
TEST_F(ParseJson, validateArrayValue)
{
// validate array.[].sub matches expected values
EXPECT_TRUE(jsonFileObj.containsKey("array.[].sub"));
auto const arrSub = jsonFileObj.getArray("array.[].sub");
EXPECT_EQ(arrSub.size(), 3);
std::vector<double> expectedArrSubVal{111.11, 4321.55, 5555.44};
std::vector<double> actualArrSubVal{};
for (auto it = arrSub.begin(); it != arrSub.end(); ++it) {
ASSERT_TRUE(std::holds_alternative<double>(*it));
actualArrSubVal.emplace_back(std::get<double>(*it));
}
EXPECT_TRUE(std::ranges::equal(expectedArrSubVal, actualArrSubVal));
// validate array.[].sub2 matches expected values
EXPECT_TRUE(jsonFileObj.containsKey("array.[].sub2"));
auto const arrSub2 = jsonFileObj.getArray("array.[].sub2");
EXPECT_EQ(arrSub2.size(), 3);
std::vector<std::string> expectedArrSub2Val{"subCategory", "temporary", "london"};
std::vector<std::string> actualArrSub2Val{};
for (auto it = arrSub2.begin(); it != arrSub2.end(); ++it) {
ASSERT_TRUE(std::holds_alternative<std::string>(*it));
actualArrSub2Val.emplace_back(std::get<std::string>(*it));
}
EXPECT_TRUE(std::ranges::equal(expectedArrSub2Val, actualArrSub2Val));
EXPECT_TRUE(jsonFileObj.containsKey("dosguard.whitelist.[]"));
auto const whitelistArr = jsonFileObj.getArray("dosguard.whitelist.[]");
EXPECT_EQ(whitelistArr.size(), 2);
EXPECT_EQ("125.5.5.1", std::get<std::string>(whitelistArr.at(0)));
EXPECT_EQ("204.2.2.1", std::get<std::string>(whitelistArr.at(1)));
}
struct JsonValueDeathTest : ParseJson {};
TEST_F(JsonValueDeathTest, invalidGetArray)
{
EXPECT_DEATH([[maybe_unused]] auto a = jsonFileObj.getArray("header.text1"), ".*");
}

View File

@@ -19,15 +19,23 @@
#include "util/newconfig/ArrayView.hpp"
#include "util/newconfig/ConfigDefinition.hpp"
#include "util/newconfig/ConfigFileJson.hpp"
#include "util/newconfig/FakeConfigData.hpp"
#include "util/newconfig/ObjectView.hpp"
#include <boost/json/parse.hpp>
#include <gtest/gtest.h>
using namespace util::config;
struct ObjectViewTest : testing::Test {
ClioConfigDefinition const configData = generateConfig();
ObjectViewTest()
{
ConfigFileJson const jsonFileObj{boost::json::parse(JSONData).as_object()};
auto const errors = configData.parse(jsonFileObj);
EXPECT_TRUE(!errors.has_value());
}
ClioConfigDefinition configData = generateConfig();
};
TEST_F(ObjectViewTest, ObjectValueTest)
@@ -39,14 +47,14 @@ TEST_F(ObjectViewTest, ObjectValueTest)
EXPECT_TRUE(headerObj.containsKey("admin"));
EXPECT_EQ("value", headerObj.getValue("text1").asString());
EXPECT_EQ(123, headerObj.getValue("port").asIntType<int>());
EXPECT_EQ(true, headerObj.getValue("admin").asBool());
EXPECT_EQ(321, headerObj.getValue("port").asIntType<int>());
EXPECT_EQ(false, headerObj.getValue("admin").asBool());
}
TEST_F(ObjectViewTest, ObjectInArray)
TEST_F(ObjectViewTest, ObjectValuesInArray)
{
ArrayView const arr = configData.getArray("array");
EXPECT_EQ(arr.size(), 2);
EXPECT_EQ(arr.size(), 3);
ObjectView const firstObj = arr.objectAt(0);
ObjectView const secondObj = arr.objectAt(1);
EXPECT_TRUE(firstObj.containsKey("sub"));
@@ -62,7 +70,7 @@ TEST_F(ObjectViewTest, ObjectInArray)
EXPECT_EQ(secondObj.getValue("sub2").asString(), "temporary");
}
TEST_F(ObjectViewTest, ObjectInArrayMoreComplex)
TEST_F(ObjectViewTest, GetObjectsInDifferentWays)
{
ArrayView const arr = configData.getArray("higher");
ASSERT_EQ(1, arr.size());
@@ -78,7 +86,7 @@ TEST_F(ObjectViewTest, ObjectInArrayMoreComplex)
ObjectView const objLow = firstObj.getObject("low");
EXPECT_TRUE(objLow.containsKey("section"));
EXPECT_TRUE(objLow.containsKey("admin"));
EXPECT_EQ(objLow.getValue("section").asString(), "true");
EXPECT_EQ(objLow.getValue("section").asString(), "WebServer");
EXPECT_EQ(objLow.getValue("admin").asBool(), false);
}
@@ -90,18 +98,25 @@ TEST_F(ObjectViewTest, getArrayInObject)
auto const arr = obj.getArray("whitelist");
EXPECT_EQ(2, arr.size());
EXPECT_EQ("125.5.5.2", arr.valueAt(0).asString());
EXPECT_EQ("204.2.2.2", arr.valueAt(1).asString());
EXPECT_EQ("125.5.5.1", arr.valueAt(0).asString());
EXPECT_EQ("204.2.2.1", arr.valueAt(1).asString());
}
struct ObjectViewDeathTest : ObjectViewTest {};
TEST_F(ObjectViewDeathTest, incorrectKeys)
TEST_F(ObjectViewDeathTest, KeyDoesNotExist)
{
EXPECT_DEATH({ [[maybe_unused]] auto _ = configData.getObject("head"); }, ".*");
}
TEST_F(ObjectViewDeathTest, KeyIsValueView)
{
EXPECT_DEATH({ [[maybe_unused]] auto _ = configData.getObject("header.text1"); }, ".*");
EXPECT_DEATH({ [[maybe_unused]] auto _ = configData.getObject("head"); }, ".*");
EXPECT_DEATH({ [[maybe_unused]] auto _ = configData.getArray("header"); }, ".*");
}
TEST_F(ObjectViewDeathTest, KeyisArrayView)
{
// dies because only 1 object in higher.[].low
EXPECT_DEATH({ [[maybe_unused]] auto _ = configData.getObject("higher.[].low", 1); }, ".*");
}

View File

@@ -20,6 +20,7 @@
#include "util/newconfig/ConfigDefinition.hpp"
#include "util/newconfig/ConfigValue.hpp"
#include "util/newconfig/FakeConfigData.hpp"
#include "util/newconfig/Types.hpp"
#include "util/newconfig/ValueView.hpp"
#include <gtest/gtest.h>

View File

@@ -25,6 +25,8 @@
#include <boost/asio/error.hpp>
#include <boost/asio/spawn.hpp>
#include <boost/beast/http/field.hpp>
#include <boost/beast/websocket/stream.hpp>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <chrono>
@@ -33,6 +35,7 @@
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <thread>
#include <vector>
@@ -318,6 +321,40 @@ TEST_F(WsConnectionTests, MultipleConnections)
}
}
TEST_F(WsConnectionTests, RespondsToPing)
{
asio::spawn(ctx, [&](asio::yield_context yield) {
auto serverConnection = unwrap(server.acceptConnection(yield));
testing::StrictMock<testing::MockFunction<void(boost::beast::websocket::frame_type, std::string_view)>>
controlFrameCallback;
serverConnection.setControlFrameCallback(controlFrameCallback.AsStdFunction());
EXPECT_CALL(controlFrameCallback, Call(boost::beast::websocket::frame_type::pong, testing::_)).WillOnce([&]() {
serverConnection.resetControlFrameCallback();
asio::spawn(ctx, [&](asio::yield_context yield) {
auto maybeError = serverConnection.send("got pong", yield);
ASSERT_FALSE(maybeError.has_value()) << *maybeError;
});
});
serverConnection.sendPing({}, yield);
auto message = serverConnection.receive(yield);
ASSERT_TRUE(message.has_value());
EXPECT_EQ(message, "hello") << message.value();
});
runSpawn([&](asio::yield_context yield) {
auto connection = builder.plainConnect(yield);
ASSERT_TRUE(connection.has_value()) << connection.error().message();
auto expectedMessage = connection->operator*().read(yield);
ASSERT_TRUE(expectedMessage) << expectedMessage.error().message();
EXPECT_EQ(expectedMessage.value(), "got pong");
auto error = connection->operator*().write("hello", yield);
ASSERT_FALSE(error) << error->message();
});
}
enum class WsConnectionErrorTestsBundle : int { Read = 1, Write = 2 };
struct WsConnectionErrorTests : WsConnectionTestsBase, testing::WithParamInterface<WsConnectionErrorTestsBundle> {};

View File

@@ -0,0 +1,210 @@
//
// Based off of https://github.com/scylladb/scylla-code-samples/blob/master/efficient_full_table_scan_example_code/efficient_full_table_scan.go
//
package main
import (
"fmt"
"log"
"os"
"strings"
"time"
"xrplf/clio/cassandra_delete_range/internal/cass"
"xrplf/clio/cassandra_delete_range/internal/util"
"github.com/alecthomas/kingpin/v2"
"github.com/gocql/gocql"
)
const (
defaultNumberOfNodesInCluster = 3
defaultNumberOfCoresInNode = 8
defaultSmudgeFactor = 3
)
var (
app = kingpin.New("cassandra_delete_range", "A tool that prunes data from the Clio DB.")
hosts = app.Flag("hosts", "Your Scylla nodes IP addresses, comma separated (i.e. 192.168.1.1,192.168.1.2,192.168.1.3)").Required().String()
deleteAfter = app.Command("delete-after", "Prunes from the given ledger index until the end")
deleteAfterLedgerIdx = deleteAfter.Arg("idx", "Sets the earliest ledger_index to keep untouched (delete everything after this ledger index)").Required().Uint64()
deleteBefore = app.Command("delete-before", "Prunes everything before the given ledger index")
deleteBeforeLedgerIdx = deleteBefore.Arg("idx", "Sets the latest ledger_index to keep around (delete everything before this ledger index)").Required().Uint64()
getLedgerRange = app.Command("get-ledger-range", "Fetch the current lender_range table values")
nodesInCluster = app.Flag("nodes-in-cluster", "Number of nodes in your Scylla cluster").Short('n').Default(fmt.Sprintf("%d", defaultNumberOfNodesInCluster)).Int()
coresInNode = app.Flag("cores-in-node", "Number of cores in each node").Short('c').Default(fmt.Sprintf("%d", defaultNumberOfCoresInNode)).Int()
smudgeFactor = app.Flag("smudge-factor", "Yet another factor to make parallelism cooler").Short('s').Default(fmt.Sprintf("%d", defaultSmudgeFactor)).Int()
clusterConsistency = app.Flag("consistency", "Cluster consistency level. Use 'localone' for multi DC").Short('o').Default("localquorum").String()
clusterTimeout = app.Flag("timeout", "Maximum duration for query execution in millisecond").Short('t').Default("15000").Int()
clusterNumConnections = app.Flag("cluster-number-of-connections", "Number of connections per host per session (in our case, per thread)").Short('b').Default("1").Int()
clusterCQLVersion = app.Flag("cql-version", "The CQL version to use").Short('l').Default("3.0.0").String()
clusterPageSize = app.Flag("cluster-page-size", "Page size of results").Short('p').Default("5000").Int()
keyspace = app.Flag("keyspace", "Keyspace to use").Short('k').Default("clio_fh").String()
userName = app.Flag("username", "Username to use when connecting to the cluster").String()
password = app.Flag("password", "Password to use when connecting to the cluster").String()
skipSuccessorTable = app.Flag("skip-successor", "Whether to skip deletion from successor table").Default("false").Bool()
skipObjectsTable = app.Flag("skip-objects", "Whether to skip deletion from objects table").Default("false").Bool()
skipLedgerHashesTable = app.Flag("skip-ledger-hashes", "Whether to skip deletion from ledger_hashes table").Default("false").Bool()
skipTransactionsTable = app.Flag("skip-transactions", "Whether to skip deletion from transactions table").Default("false").Bool()
skipDiffTable = app.Flag("skip-diff", "Whether to skip deletion from diff table").Default("false").Bool()
skipLedgerTransactionsTable = app.Flag("skip-ledger-transactions", "Whether to skip deletion from ledger_transactions table").Default("false").Bool()
skipLedgersTable = app.Flag("skip-ledgers", "Whether to skip deletion from ledgers table").Default("false").Bool()
skipWriteLatestLedger = app.Flag("skip-write-latest-ledger", "Whether to skip writing the latest ledger index").Default("false").Bool()
workerCount = 1 // the calculated number of parallel goroutines the client should run
ranges []*util.TokenRange // the calculated ranges to be executed in parallel
)
func main() {
log.SetOutput(os.Stdout)
command := kingpin.MustParse(app.Parse(os.Args[1:]))
cluster, err := prepareDb(hosts)
if err != nil {
log.Fatal(err)
}
clioCass := cass.NewClioCass(&cass.Settings{
SkipSuccessorTable: *skipSuccessorTable,
SkipObjectsTable: *skipObjectsTable,
SkipLedgerHashesTable: *skipLedgerHashesTable,
SkipTransactionsTable: *skipTransactionsTable,
SkipDiffTable: *skipDiffTable,
SkipLedgerTransactionsTable: *skipLedgerHashesTable,
SkipLedgersTable: *skipLedgersTable,
SkipWriteLatestLedger: *skipWriteLatestLedger,
WorkerCount: workerCount,
Ranges: ranges}, cluster)
switch command {
case deleteAfter.FullCommand():
if *deleteAfterLedgerIdx == 0 {
log.Println("Please specify ledger index to delete from")
return
}
displayParams("delete-after", hosts, cluster.Timeout/1000/1000, *deleteAfterLedgerIdx)
log.Printf("Will delete everything after ledger index %d (exclusive) and till latest\n", *deleteAfterLedgerIdx)
log.Println("WARNING: Please make sure that there are no Clio writers operating on the DB while this script is running")
if !util.PromptContinue() {
log.Fatal("Aborted")
}
startTime := time.Now().UTC()
clioCass.DeleteAfter(*deleteAfterLedgerIdx)
fmt.Printf("Total Execution Time: %s\n\n", time.Since(startTime))
fmt.Println("NOTE: Cassandra/ScyllaDB only writes tombstones. You need to run compaction to free up disk space.")
case deleteBefore.FullCommand():
if *deleteBeforeLedgerIdx == 0 {
log.Println("Please specify ledger index to delete until")
return
}
displayParams("delete-before", hosts, cluster.Timeout/1000/1000, *deleteBeforeLedgerIdx)
log.Printf("Will delete everything before ledger index %d (exclusive)\n", *deleteBeforeLedgerIdx)
log.Println("WARNING: Please make sure that there are no Clio writers operating on the DB while this script is running")
if !util.PromptContinue() {
log.Fatal("Aborted")
}
startTime := time.Now().UTC()
clioCass.DeleteBefore(*deleteBeforeLedgerIdx)
fmt.Printf("Total Execution Time: %s\n\n", time.Since(startTime))
fmt.Println("NOTE: Cassandra/ScyllaDB only writes tombstones. You need to run compaction to free up disk space.")
case getLedgerRange.FullCommand():
from, to, err := clioCass.GetLedgerRange()
if err != nil {
log.Fatal(err)
}
fmt.Printf("Range: %d -> %d\n", from, to)
}
}
func displayParams(command string, hosts *string, timeout time.Duration, ledgerIdx uint64) {
runParameters := fmt.Sprintf(`
Execution Parameters:
=====================
Command : %s
Ledger index : %d
Scylla cluster nodes : %s
Keyspace : %s
Consistency : %s
Timeout (ms) : %d
Connections per host : %d
CQL Version : %s
Page size : %d
# of parallel threads : %d
# of ranges to be executed : %d
Skip deletion of:
- successor table : %t
- objects table : %t
- ledger_hashes table : %t
- transactions table : %t
- diff table : %t
- ledger_transactions table : %t
- ledgers table : %t
Will update ledger_range : %t
`,
command,
ledgerIdx,
*hosts,
*keyspace,
*clusterConsistency,
timeout,
*clusterNumConnections,
*clusterCQLVersion,
*clusterPageSize,
workerCount,
len(ranges),
*skipSuccessorTable || command == "delete-before",
*skipObjectsTable,
*skipLedgerHashesTable,
*skipTransactionsTable,
*skipDiffTable,
*skipLedgerTransactionsTable,
*skipLedgersTable,
!*skipWriteLatestLedger)
fmt.Println(runParameters)
}
func prepareDb(dbHosts *string) (*gocql.ClusterConfig, error) {
workerCount = (*nodesInCluster) * (*coresInNode) * (*smudgeFactor)
ranges = util.GetTokenRanges(workerCount)
util.Shuffle(ranges)
hosts := strings.Split(*dbHosts, ",")
cluster := gocql.NewCluster(hosts...)
cluster.Consistency = util.GetConsistencyLevel(*clusterConsistency)
cluster.Timeout = time.Duration(*clusterTimeout * 1000 * 1000)
cluster.NumConns = *clusterNumConnections
cluster.CQLVersion = *clusterCQLVersion
cluster.PageSize = *clusterPageSize
cluster.Keyspace = *keyspace
if *userName != "" {
cluster.Authenticator = gocql.PasswordAuthenticator{
Username: *userName,
Password: *password,
}
}
return cluster, nil
}

View File

@@ -3,9 +3,13 @@ module xrplf/clio/cassandra_delete_range
go 1.21.6
require (
github.com/alecthomas/kingpin/v2 v2.4.0 // indirect
github.com/alecthomas/kingpin/v2 v2.4.0
github.com/gocql/gocql v1.6.0
github.com/pmorelli92/maybe v1.1.0
)
require (
github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137 // indirect
github.com/gocql/gocql v1.6.0 // indirect
github.com/golang/snappy v0.0.3 // indirect
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect
github.com/xhit/go-str2duration/v2 v2.1.0 // indirect

View File

@@ -2,25 +2,38 @@ github.com/alecthomas/kingpin/v2 v2.4.0 h1:f48lwail6p8zpO1bC4TxtqACaGqHYA22qkHjH
github.com/alecthomas/kingpin/v2 v2.4.0/go.mod h1:0gyi0zQnjuFk8xrkNKamJoyUo382HRL7ATRpFZCw6tE=
github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137 h1:s6gZFSlWYmbqAuRjVTiNNhvNRfY2Wxp9nhfyel4rklc=
github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137/go.mod h1:OMCwj8VM1Kc9e19TLln2VL61YJF0x1XFtfdL4JdbSyE=
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY=
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gocql/gocql v1.6.0 h1:IdFdOTbnpbd0pDhl4REKQDM+Q0SzKXQ1Yh+YZZ8T/qU=
github.com/gocql/gocql v1.6.0/go.mod h1:3gM2c4D3AnkISwBxGnMMsS8Oy4y2lhbPRsH4xnJrHG8=
github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmorelli92/maybe v1.1.0 h1:uyV6NLF4453AQARZ6rKpJNzc9PBsQmpGDtUonhxInPU=
github.com/pmorelli92/maybe v1.1.0/go.mod h1:5PrW2+fo4/j/LMX6HT49Hb3/HOKv1tbodkzgy4lEopA=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/xhit/go-str2duration/v2 v2.1.0 h1:lxklc02Drh6ynqX+DdPyp5pCKLUQpRT8bp8Ydu2Bstc=
github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtXVyJfNt1+BlmyAsU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -0,0 +1,564 @@
package cass
import (
"fmt"
"log"
"os"
"slices"
"strconv"
"strings"
"sync"
"sync/atomic"
"xrplf/clio/cassandra_delete_range/internal/util"
"github.com/gocql/gocql"
"github.com/pmorelli92/maybe"
)
type deleteInfo struct {
Query string
Data []deleteParams
}
type deleteParams struct {
Seq uint64
Blob []byte // hash, key, etc
}
type columnSettings struct {
UseSeq bool
UseBlob bool
}
type Settings struct {
SkipSuccessorTable bool
SkipObjectsTable bool
SkipLedgerHashesTable bool
SkipTransactionsTable bool
SkipDiffTable bool
SkipLedgerTransactionsTable bool
SkipLedgersTable bool
SkipWriteLatestLedger bool
WorkerCount int
Ranges []*util.TokenRange
}
type Cass interface {
GetLedgerRange() (uint64, uint64, error)
DeleteBefore(ledgerIdx uint64)
DeleteAfter(ledgerIdx uint64)
}
type ClioCass struct {
settings *Settings
clusterConfig *gocql.ClusterConfig
}
func NewClioCass(settings *Settings, cluster *gocql.ClusterConfig) *ClioCass {
return &ClioCass{settings, cluster}
}
func (c *ClioCass) DeleteBefore(ledgerIdx uint64) {
firstLedgerIdxInDB, latestLedgerIdxInDB, err := c.GetLedgerRange()
if err != nil {
log.Fatal(err)
}
log.Printf("DB ledger range is %d -> %d\n", firstLedgerIdxInDB, latestLedgerIdxInDB)
if firstLedgerIdxInDB > ledgerIdx {
log.Fatal("Earliest ledger index in DB is greater than the one specified. Aborting...")
}
if latestLedgerIdxInDB < ledgerIdx {
log.Fatal("Latest ledger index in DB is smaller than the one specified. Aborting...")
}
var (
from maybe.Maybe[uint64] // not used
to maybe.Maybe[uint64] = maybe.Set(ledgerIdx - 1)
)
c.settings.SkipSuccessorTable = true // skip successor update until we know how to do it
if err := c.pruneData(from, to, firstLedgerIdxInDB, latestLedgerIdxInDB); err != nil {
log.Fatal(err)
}
}
func (c *ClioCass) DeleteAfter(ledgerIdx uint64) {
firstLedgerIdxInDB, latestLedgerIdxInDB, err := c.GetLedgerRange()
if err != nil {
log.Fatal(err)
}
log.Printf("DB ledger range is %d -> %d\n", firstLedgerIdxInDB, latestLedgerIdxInDB)
if firstLedgerIdxInDB > ledgerIdx {
log.Fatal("Earliest ledger index in DB is greater than the one specified. Aborting...")
}
if latestLedgerIdxInDB < ledgerIdx {
log.Fatal("Latest ledger index in DB is smaller than the one specified. Aborting...")
}
var (
from maybe.Maybe[uint64] = maybe.Set(ledgerIdx + 1)
to maybe.Maybe[uint64] // not used
)
if err := c.pruneData(from, to, firstLedgerIdxInDB, latestLedgerIdxInDB); err != nil {
log.Fatal(err)
}
}
func (c *ClioCass) GetLedgerRange() (uint64, uint64, error) {
var (
firstLedgerIdx uint64
latestLedgerIdx uint64
)
session, err := c.clusterConfig.CreateSession()
if err != nil {
log.Fatal(err)
}
defer session.Close()
if err := session.Query("SELECT sequence FROM ledger_range WHERE is_latest = ?", false).Scan(&firstLedgerIdx); err != nil {
return 0, 0, err
}
if err := session.Query("SELECT sequence FROM ledger_range WHERE is_latest = ?", true).Scan(&latestLedgerIdx); err != nil {
return 0, 0, err
}
return firstLedgerIdx, latestLedgerIdx, nil
}
func (c *ClioCass) pruneData(
fromLedgerIdx maybe.Maybe[uint64],
toLedgerIdx maybe.Maybe[uint64],
firstLedgerIdxInDB uint64,
latestLedgerIdxInDB uint64,
) error {
var totalErrors uint64
var totalRows uint64
var totalDeletes uint64
var info deleteInfo
var rowsCount uint64
var deleteCount uint64
var errCount uint64
// calculate range of simple delete queries
var (
rangeFrom uint64 = firstLedgerIdxInDB
rangeTo uint64 = latestLedgerIdxInDB
)
if fromLedgerIdx.HasValue() {
rangeFrom = fromLedgerIdx.Value()
}
if toLedgerIdx.HasValue() {
rangeTo = toLedgerIdx.Value()
}
// calculate and print deletion plan
fromStr := "beginning"
if fromLedgerIdx.HasValue() {
fromStr = strconv.Itoa(int(fromLedgerIdx.Value()))
}
toStr := "latest"
if toLedgerIdx.HasValue() {
toStr = strconv.Itoa(int(toLedgerIdx.Value()))
}
log.Printf("Start scanning and removing data for %s -> %s\n\n", fromStr, toStr)
// successor queries
if !c.settings.SkipSuccessorTable {
log.Println("Generating delete queries for successor table")
info, rowsCount, errCount = c.prepareDeleteQueries(fromLedgerIdx, toLedgerIdx,
"SELECT key, seq FROM successor WHERE token(key) >= ? AND token(key) <= ?",
"DELETE FROM successor WHERE key = ? AND seq = ?", false)
log.Printf("Total delete queries: %d\n", len(info.Data))
log.Printf("Total traversed rows: %d\n\n", rowsCount)
totalErrors += errCount
totalRows += rowsCount
deleteCount, errCount = c.performDeleteQueries(&info, columnSettings{UseBlob: true, UseSeq: true})
totalErrors += errCount
totalDeletes += deleteCount
}
// objects queries
if !c.settings.SkipObjectsTable {
log.Println("Generating delete queries for objects table")
info, rowsCount, errCount = c.prepareDeleteQueries(fromLedgerIdx, toLedgerIdx,
"SELECT key, sequence FROM objects WHERE token(key) >= ? AND token(key) <= ?",
"DELETE FROM objects WHERE key = ? AND sequence = ?", true)
log.Printf("Total delete queries: %d\n", len(info.Data))
log.Printf("Total traversed rows: %d\n\n", rowsCount)
totalErrors += errCount
totalRows += rowsCount
deleteCount, errCount = c.performDeleteQueries(&info, columnSettings{UseBlob: true, UseSeq: true})
totalErrors += errCount
totalDeletes += deleteCount
}
// ledger_hashes queries
if !c.settings.SkipLedgerHashesTable {
log.Println("Generating delete queries for ledger_hashes table")
info, rowsCount, errCount = c.prepareDeleteQueries(fromLedgerIdx, toLedgerIdx,
"SELECT hash, sequence FROM ledger_hashes WHERE token(hash) >= ? AND token(hash) <= ?",
"DELETE FROM ledger_hashes WHERE hash = ?", false)
log.Printf("Total delete queries: %d\n", len(info.Data))
log.Printf("Total traversed rows: %d\n\n", rowsCount)
totalErrors += errCount
totalRows += rowsCount
deleteCount, errCount = c.performDeleteQueries(&info, columnSettings{UseBlob: true, UseSeq: false})
totalErrors += errCount
totalDeletes += deleteCount
}
// transactions queries
if !c.settings.SkipTransactionsTable {
log.Println("Generating delete queries for transactions table")
info, rowsCount, errCount = c.prepareDeleteQueries(fromLedgerIdx, toLedgerIdx,
"SELECT hash, ledger_sequence FROM transactions WHERE token(hash) >= ? AND token(hash) <= ?",
"DELETE FROM transactions WHERE hash = ?", false)
log.Printf("Total delete queries: %d\n", len(info.Data))
log.Printf("Total traversed rows: %d\n\n", rowsCount)
totalErrors += errCount
totalRows += rowsCount
deleteCount, errCount = c.performDeleteQueries(&info, columnSettings{UseBlob: true, UseSeq: false})
totalErrors += errCount
totalDeletes += deleteCount
}
// diff queries
if !c.settings.SkipDiffTable {
log.Println("Generating delete queries for diff table")
info = c.prepareSimpleDeleteQueries(rangeFrom, rangeTo,
"DELETE FROM diff WHERE seq = ?")
log.Printf("Total delete queries: %d\n\n", len(info.Data))
deleteCount, errCount = c.performDeleteQueries(&info, columnSettings{UseBlob: false, UseSeq: true})
totalErrors += errCount
totalDeletes += deleteCount
}
// ledger_transactions queries
if !c.settings.SkipLedgerTransactionsTable {
log.Println("Generating delete queries for ledger_transactions table")
info = c.prepareSimpleDeleteQueries(rangeFrom, rangeTo,
"DELETE FROM ledger_transactions WHERE ledger_sequence = ?")
log.Printf("Total delete queries: %d\n\n", len(info.Data))
deleteCount, errCount = c.performDeleteQueries(&info, columnSettings{UseBlob: false, UseSeq: true})
totalErrors += errCount
totalDeletes += deleteCount
}
// ledgers queries
if !c.settings.SkipLedgersTable {
log.Println("Generating delete queries for ledgers table")
info = c.prepareSimpleDeleteQueries(rangeFrom, rangeTo,
"DELETE FROM ledgers WHERE sequence = ?")
log.Printf("Total delete queries: %d\n\n", len(info.Data))
deleteCount, errCount = c.performDeleteQueries(&info, columnSettings{UseBlob: false, UseSeq: true})
totalErrors += errCount
totalDeletes += deleteCount
}
// TODO: tbd what to do with account_tx as it got tuple for seq_idx
// TODO: also, whether we need to take care of nft tables and other stuff like that
if !c.settings.SkipWriteLatestLedger {
var (
first maybe.Maybe[uint64]
last maybe.Maybe[uint64]
)
if fromLedgerIdx.HasValue() {
last = maybe.Set(fromLedgerIdx.Value() - 1)
}
if toLedgerIdx.HasValue() {
first = maybe.Set(toLedgerIdx.Value() + 1)
}
if err := c.updateLedgerRange(first, last); err != nil {
log.Printf("ERROR failed updating ledger range: %s\n", err)
return err
}
}
log.Printf("TOTAL ERRORS: %d\n", totalErrors)
log.Printf("TOTAL ROWS TRAVERSED: %d\n", totalRows)
log.Printf("TOTAL DELETES: %d\n\n", totalDeletes)
log.Printf("Completed deletion for %s -> %s\n\n", fromStr, toStr)
return nil
}
func (c *ClioCass) prepareSimpleDeleteQueries(
fromLedgerIdx uint64,
toLedgerIdx uint64,
deleteQueryTemplate string,
) deleteInfo {
var info = deleteInfo{Query: deleteQueryTemplate}
for i := fromLedgerIdx; i <= toLedgerIdx; i++ {
info.Data = append(info.Data, deleteParams{Seq: i})
}
return info
}
func (c *ClioCass) prepareDeleteQueries(
fromLedgerIdx maybe.Maybe[uint64],
toLedgerIdx maybe.Maybe[uint64],
queryTemplate string,
deleteQueryTemplate string,
keepLastValid bool,
) (deleteInfo, uint64, uint64) {
rangesChannel := make(chan *util.TokenRange, len(c.settings.Ranges))
for i := range c.settings.Ranges {
rangesChannel <- c.settings.Ranges[i]
}
close(rangesChannel)
outChannel := make(chan deleteParams)
var info = deleteInfo{Query: deleteQueryTemplate}
go func() {
total := uint64(0)
for params := range outChannel {
total += 1
if total%1000 == 0 {
log.Printf("... %d queries ...\n", total)
}
info.Data = append(info.Data, params)
}
}()
var wg sync.WaitGroup
var sessionCreationWaitGroup sync.WaitGroup
var totalRows uint64
var totalErrors uint64
wg.Add(c.settings.WorkerCount)
sessionCreationWaitGroup.Add(c.settings.WorkerCount)
for i := 0; i < c.settings.WorkerCount; i++ {
go func(q string) {
defer wg.Done()
var session *gocql.Session
var err error
if session, err = c.clusterConfig.CreateSession(); err == nil {
defer session.Close()
sessionCreationWaitGroup.Done()
sessionCreationWaitGroup.Wait()
preparedQuery := session.Query(q)
for r := range rangesChannel {
preparedQuery.Bind(r.StartRange, r.EndRange)
var pageState []byte
var rowsRetrieved uint64
var previousKey []byte
var foundLastValid bool
for {
iter := preparedQuery.PageSize(c.clusterConfig.PageSize).PageState(pageState).Iter()
nextPageState := iter.PageState()
scanner := iter.Scanner()
for scanner.Next() {
var key []byte
var seq uint64
err = scanner.Scan(&key, &seq)
if err == nil {
rowsRetrieved++
if keepLastValid && !slices.Equal(previousKey, key) {
previousKey = key
foundLastValid = false
}
// only grab the rows that are in the correct range of sequence numbers
if fromLedgerIdx.HasValue() && fromLedgerIdx.Value() <= seq {
outChannel <- deleteParams{Seq: seq, Blob: key}
} else if toLedgerIdx.HasValue() {
if seq < toLedgerIdx.Value() && (!keepLastValid || foundLastValid) {
outChannel <- deleteParams{Seq: seq, Blob: key}
} else if seq <= toLedgerIdx.Value()+1 {
foundLastValid = true
}
}
} else {
log.Printf("ERROR: page iteration failed: %s\n", err)
fmt.Fprintf(os.Stderr, "FAILED QUERY: %s\n", fmt.Sprintf("%s [from=%d][to=%d][pagestate=%x]", queryTemplate, r.StartRange, r.EndRange, pageState))
atomic.AddUint64(&totalErrors, 1)
}
}
if len(nextPageState) == 0 {
break
}
pageState = nextPageState
}
atomic.AddUint64(&totalRows, rowsRetrieved)
}
} else {
log.Printf("ERROR: %s\n", err)
fmt.Fprintf(os.Stderr, "FAILED TO CREATE SESSION: %s\n", err)
atomic.AddUint64(&totalErrors, 1)
}
}(queryTemplate)
}
wg.Wait()
close(outChannel)
return info, totalRows, totalErrors
}
func (c *ClioCass) splitDeleteWork(info *deleteInfo) [][]deleteParams {
var n = c.settings.WorkerCount
var chunkSize = len(info.Data) / n
var chunks [][]deleteParams
if len(info.Data) == 0 {
return chunks
}
if chunkSize < 1 {
chunks = append(chunks, info.Data)
return chunks
}
for i := 0; i < len(info.Data); i += chunkSize {
end := i + chunkSize
if end > len(info.Data) {
end = len(info.Data)
}
chunks = append(chunks, info.Data[i:end])
}
return chunks
}
func (c *ClioCass) performDeleteQueries(info *deleteInfo, colSettings columnSettings) (uint64, uint64) {
var wg sync.WaitGroup
var sessionCreationWaitGroup sync.WaitGroup
var totalDeletes uint64
var totalErrors uint64
chunks := c.splitDeleteWork(info)
chunksChannel := make(chan []deleteParams, len(chunks))
for i := range chunks {
chunksChannel <- chunks[i]
}
close(chunksChannel)
wg.Add(c.settings.WorkerCount)
sessionCreationWaitGroup.Add(c.settings.WorkerCount)
query := info.Query
bindCount := strings.Count(query, "?")
for i := 0; i < c.settings.WorkerCount; i++ {
go func(number int, q string, bc int) {
defer wg.Done()
var session *gocql.Session
var err error
if session, err = c.clusterConfig.CreateSession(); err == nil {
defer session.Close()
sessionCreationWaitGroup.Done()
sessionCreationWaitGroup.Wait()
preparedQuery := session.Query(q)
for chunk := range chunksChannel {
for _, r := range chunk {
if bc == 2 {
preparedQuery.Bind(r.Blob, r.Seq)
} else if bc == 1 {
if colSettings.UseSeq {
preparedQuery.Bind(r.Seq)
} else if colSettings.UseBlob {
preparedQuery.Bind(r.Blob)
}
}
if err := preparedQuery.Exec(); err != nil {
log.Printf("DELETE ERROR: %s\n", err)
fmt.Fprintf(os.Stderr, "FAILED QUERY: %s\n", fmt.Sprintf("%s [blob=0x%x][seq=%d]", info.Query, r.Blob, r.Seq))
atomic.AddUint64(&totalErrors, 1)
} else {
atomic.AddUint64(&totalDeletes, 1)
if atomic.LoadUint64(&totalDeletes)%10000 == 0 {
log.Printf("... %d deletes ...\n", totalDeletes)
}
}
}
}
} else {
log.Printf("ERROR: %s\n", err)
fmt.Fprintf(os.Stderr, "FAILED TO CREATE SESSION: %s\n", err)
atomic.AddUint64(&totalErrors, 1)
}
}(i, query, bindCount)
}
wg.Wait()
return totalDeletes, totalErrors
}
func (c *ClioCass) updateLedgerRange(newStartLedger maybe.Maybe[uint64], newEndLedger maybe.Maybe[uint64]) error {
if session, err := c.clusterConfig.CreateSession(); err == nil {
defer session.Close()
query := "UPDATE ledger_range SET sequence = ? WHERE is_latest = ?"
if newEndLedger.HasValue() {
log.Printf("Updating ledger range end to %d\n", newEndLedger.Value())
preparedQuery := session.Query(query, newEndLedger.Value(), true)
if err := preparedQuery.Exec(); err != nil {
fmt.Fprintf(os.Stderr, "FAILED QUERY: %s [seq=%d][true]\n", query, newEndLedger.Value())
return err
}
}
if newStartLedger.HasValue() {
log.Printf("Updating ledger range start to %d\n", newStartLedger.Value())
preparedQuery := session.Query(query, newStartLedger.Value(), false)
if err := preparedQuery.Exec(); err != nil {
fmt.Fprintf(os.Stderr, "FAILED QUERY: %s [seq=%d][false]\n", query, newStartLedger.Value())
return err
}
}
} else {
fmt.Fprintf(os.Stderr, "FAILED TO CREATE SESSION: %s\n", err)
return err
}
return nil
}

View File

@@ -0,0 +1,91 @@
package util
import (
"fmt"
"log"
"math"
"math/rand"
"github.com/gocql/gocql"
)
type TokenRange struct {
StartRange int64
EndRange int64
}
func Shuffle(data []*TokenRange) {
for i := 1; i < len(data); i++ {
r := rand.Intn(i + 1)
if i != r {
data[r], data[i] = data[i], data[r]
}
}
}
func PromptContinue() bool {
var continueFlag string
log.Println("Are you sure you want to continue? (y/n)")
if fmt.Scanln(&continueFlag); continueFlag != "y" {
return false
}
return true
}
func GetTokenRanges(workerCount int) []*TokenRange {
var n = workerCount
var m = int64(n * 100)
var maxSize uint64 = math.MaxInt64 * 2
var rangeSize = maxSize / uint64(m)
var start int64 = math.MinInt64
var end int64
var shouldBreak = false
var ranges = make([]*TokenRange, m)
for i := int64(0); i < m; i++ {
end = start + int64(rangeSize)
if start > 0 && end < 0 {
end = math.MaxInt64
shouldBreak = true
}
ranges[i] = &TokenRange{StartRange: start, EndRange: end}
if shouldBreak {
break
}
start = end + 1
}
return ranges
}
func GetConsistencyLevel(consistencyValue string) gocql.Consistency {
switch consistencyValue {
case "any":
return gocql.Any
case "one":
return gocql.One
case "two":
return gocql.Two
case "three":
return gocql.Three
case "quorum":
return gocql.Quorum
case "all":
return gocql.All
case "localquorum":
return gocql.LocalQuorum
case "eachquorum":
return gocql.EachQuorum
case "localone":
return gocql.LocalOne
default:
return gocql.One
}
}

View File

@@ -1,619 +0,0 @@
//
// Based off of https://github.com/scylladb/scylla-code-samples/blob/master/efficient_full_table_scan_example_code/efficient_full_table_scan.go
//
package main
import (
"fmt"
"log"
"math"
"math/rand"
"os"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/alecthomas/kingpin/v2"
"github.com/gocql/gocql"
)
const (
defaultNumberOfNodesInCluster = 3
defaultNumberOfCoresInNode = 8
defaultSmudgeFactor = 3
)
var (
clusterHosts = kingpin.Arg("hosts", "Your Scylla nodes IP addresses, comma separated (i.e. 192.168.1.1,192.168.1.2,192.168.1.3)").Required().String()
earliestLedgerIdx = kingpin.Flag("ledgerIdx", "Sets the earliest ledger_index to keep untouched").Short('i').Required().Uint64()
nodesInCluster = kingpin.Flag("nodes-in-cluster", "Number of nodes in your Scylla cluster").Short('n').Default(fmt.Sprintf("%d", defaultNumberOfNodesInCluster)).Int()
coresInNode = kingpin.Flag("cores-in-node", "Number of cores in each node").Short('c').Default(fmt.Sprintf("%d", defaultNumberOfCoresInNode)).Int()
smudgeFactor = kingpin.Flag("smudge-factor", "Yet another factor to make parallelism cooler").Short('s').Default(fmt.Sprintf("%d", defaultSmudgeFactor)).Int()
clusterConsistency = kingpin.Flag("consistency", "Cluster consistency level. Use 'localone' for multi DC").Short('o').Default("localquorum").String()
clusterTimeout = kingpin.Flag("timeout", "Maximum duration for query execution in millisecond").Short('t').Default("15000").Int()
clusterNumConnections = kingpin.Flag("cluster-number-of-connections", "Number of connections per host per session (in our case, per thread)").Short('b').Default("1").Int()
clusterCQLVersion = kingpin.Flag("cql-version", "The CQL version to use").Short('l').Default("3.0.0").String()
clusterPageSize = kingpin.Flag("cluster-page-size", "Page size of results").Short('p').Default("5000").Int()
keyspace = kingpin.Flag("keyspace", "Keyspace to use").Short('k').Default("clio_fh").String()
userName = kingpin.Flag("username", "Username to use when connecting to the cluster").String()
password = kingpin.Flag("password", "Password to use when connecting to the cluster").String()
skipSuccessorTable = kingpin.Flag("skip-successor", "Whether to skip deletion from successor table").Default("false").Bool()
skipObjectsTable = kingpin.Flag("skip-objects", "Whether to skip deletion from objects table").Default("false").Bool()
skipLedgerHashesTable = kingpin.Flag("skip-ledger-hashes", "Whether to skip deletion from ledger_hashes table").Default("false").Bool()
skipTransactionsTable = kingpin.Flag("skip-transactions", "Whether to skip deletion from transactions table").Default("false").Bool()
skipDiffTable = kingpin.Flag("skip-diff", "Whether to skip deletion from diff table").Default("false").Bool()
skipLedgerTransactionsTable = kingpin.Flag("skip-ledger-transactions", "Whether to skip deletion from ledger_transactions table").Default("false").Bool()
skipLedgersTable = kingpin.Flag("skip-ledgers", "Whether to skip deletion from ledgers table").Default("false").Bool()
skipWriteLatestLedger = kingpin.Flag("skip-write-latest-ledger", "Whether to skip writing the latest ledger index").Default("false").Bool()
workerCount = 1 // the calculated number of parallel goroutines the client should run
ranges []*tokenRange // the calculated ranges to be executed in parallel
)
type tokenRange struct {
StartRange int64
EndRange int64
}
type deleteParams struct {
Seq uint64
Blob []byte // hash, key, etc
}
type columnSettings struct {
UseSeq bool
UseBlob bool
}
type deleteInfo struct {
Query string
Data []deleteParams
}
func getTokenRanges() []*tokenRange {
var n = workerCount
var m = int64(n * 100)
var maxSize uint64 = math.MaxInt64 * 2
var rangeSize = maxSize / uint64(m)
var start int64 = math.MinInt64
var end int64
var shouldBreak = false
var ranges = make([]*tokenRange, m)
for i := int64(0); i < m; i++ {
end = start + int64(rangeSize)
if start > 0 && end < 0 {
end = math.MaxInt64
shouldBreak = true
}
ranges[i] = &tokenRange{StartRange: start, EndRange: end}
if shouldBreak {
break
}
start = end + 1
}
return ranges
}
func splitDeleteWork(info *deleteInfo) [][]deleteParams {
var n = workerCount
var chunkSize = len(info.Data) / n
var chunks [][]deleteParams
if len(info.Data) == 0 {
return chunks
}
if chunkSize < 1 {
chunks = append(chunks, info.Data)
return chunks
}
for i := 0; i < len(info.Data); i += chunkSize {
end := i + chunkSize
if end > len(info.Data) {
end = len(info.Data)
}
chunks = append(chunks, info.Data[i:end])
}
return chunks
}
func shuffle(data []*tokenRange) {
for i := 1; i < len(data); i++ {
r := rand.Intn(i + 1)
if i != r {
data[r], data[i] = data[i], data[r]
}
}
}
func getConsistencyLevel(consistencyValue string) gocql.Consistency {
switch consistencyValue {
case "any":
return gocql.Any
case "one":
return gocql.One
case "two":
return gocql.Two
case "three":
return gocql.Three
case "quorum":
return gocql.Quorum
case "all":
return gocql.All
case "localquorum":
return gocql.LocalQuorum
case "eachquorum":
return gocql.EachQuorum
case "localone":
return gocql.LocalOne
default:
return gocql.One
}
}
func main() {
log.SetOutput(os.Stdout)
kingpin.Parse()
workerCount = (*nodesInCluster) * (*coresInNode) * (*smudgeFactor)
ranges = getTokenRanges()
shuffle(ranges)
hosts := strings.Split(*clusterHosts, ",")
cluster := gocql.NewCluster(hosts...)
cluster.Consistency = getConsistencyLevel(*clusterConsistency)
cluster.Timeout = time.Duration(*clusterTimeout * 1000 * 1000)
cluster.NumConns = *clusterNumConnections
cluster.CQLVersion = *clusterCQLVersion
cluster.PageSize = *clusterPageSize
cluster.Keyspace = *keyspace
if *userName != "" {
cluster.Authenticator = gocql.PasswordAuthenticator{
Username: *userName,
Password: *password,
}
}
if *earliestLedgerIdx == 0 {
log.Println("Please specify ledger index to delete from")
return
}
runParameters := fmt.Sprintf(`
Execution Parameters:
=====================
Range to be deleted : %d -> latest
Scylla cluster nodes : %s
Keyspace : %s
Consistency : %s
Timeout (ms) : %d
Connections per host : %d
CQL Version : %s
Page size : %d
# of parallel threads : %d
# of ranges to be executed : %d
Skip deletion of:
- successor table : %t
- objects table : %t
- ledger_hashes table : %t
- transactions table : %t
- diff table : %t
- ledger_transactions table : %t
- ledgers table : %t
Will rite latest ledger : %t
`,
*earliestLedgerIdx,
*clusterHosts,
*keyspace,
*clusterConsistency,
cluster.Timeout/1000/1000,
*clusterNumConnections,
*clusterCQLVersion,
*clusterPageSize,
workerCount,
len(ranges),
*skipSuccessorTable,
*skipObjectsTable,
*skipLedgerHashesTable,
*skipTransactionsTable,
*skipDiffTable,
*skipLedgerTransactionsTable,
*skipLedgersTable,
!*skipWriteLatestLedger)
fmt.Println(runParameters)
log.Printf("Will delete everything after ledger index %d (exclusive) and till latest\n", *earliestLedgerIdx)
log.Println("WARNING: Please make sure that there are no Clio writers operating on the DB while this script is running")
log.Println("Are you sure you want to continue? (y/n)")
var continueFlag string
if fmt.Scanln(&continueFlag); continueFlag != "y" {
log.Println("Aborting...")
return
}
startTime := time.Now().UTC()
earliestLedgerIdxInDB, latestLedgerIdxInDB, err := getLedgerRange(cluster)
if err != nil {
log.Fatal(err)
}
if earliestLedgerIdxInDB > *earliestLedgerIdx {
log.Fatal("Earliest ledger index in DB is greater than the one specified. Aborting...")
}
if latestLedgerIdxInDB < *earliestLedgerIdx {
log.Fatal("Latest ledger index in DB is smaller than the one specified. Aborting...")
}
if err := deleteLedgerData(cluster, *earliestLedgerIdx+1, latestLedgerIdxInDB); err != nil {
log.Fatal(err)
}
fmt.Printf("Total Execution Time: %s\n\n", time.Since(startTime))
fmt.Println("NOTE: Cassandra/ScyllaDB only writes tombstones. You need to run compaction to free up disk space.")
}
func getLedgerRange(cluster *gocql.ClusterConfig) (uint64, uint64, error) {
var (
firstLedgerIdx uint64
latestLedgerIdx uint64
)
session, err := cluster.CreateSession()
if err != nil {
log.Fatal(err)
}
defer session.Close()
if err := session.Query("select sequence from ledger_range where is_latest = ?", false).Scan(&firstLedgerIdx); err != nil {
return 0, 0, err
}
if err := session.Query("select sequence from ledger_range where is_latest = ?", true).Scan(&latestLedgerIdx); err != nil {
return 0, 0, err
}
log.Printf("DB ledger range is %d:%d\n", firstLedgerIdx, latestLedgerIdx)
return firstLedgerIdx, latestLedgerIdx, nil
}
func deleteLedgerData(cluster *gocql.ClusterConfig, fromLedgerIdx uint64, toLedgerIdx uint64) error {
var totalErrors uint64
var totalRows uint64
var totalDeletes uint64
var info deleteInfo
var rowsCount uint64
var deleteCount uint64
var errCount uint64
log.Printf("Start scanning and removing data for %d -> latest (%d according to ledger_range table)\n\n", fromLedgerIdx, toLedgerIdx)
// successor queries
if !*skipSuccessorTable {
log.Println("Generating delete queries for successor table")
info, rowsCount, errCount = prepareDeleteQueries(cluster, fromLedgerIdx,
"SELECT key, seq FROM successor WHERE token(key) >= ? AND token(key) <= ?",
"DELETE FROM successor WHERE key = ? AND seq = ?")
log.Printf("Total delete queries: %d\n", len(info.Data))
log.Printf("Total traversed rows: %d\n\n", rowsCount)
totalErrors += errCount
totalRows += rowsCount
deleteCount, errCount = performDeleteQueries(cluster, &info, columnSettings{UseBlob: true, UseSeq: true})
totalErrors += errCount
totalDeletes += deleteCount
}
// objects queries
if !*skipObjectsTable {
log.Println("Generating delete queries for objects table")
info, rowsCount, errCount = prepareDeleteQueries(cluster, fromLedgerIdx,
"SELECT key, sequence FROM objects WHERE token(key) >= ? AND token(key) <= ?",
"DELETE FROM objects WHERE key = ? AND sequence = ?")
log.Printf("Total delete queries: %d\n", len(info.Data))
log.Printf("Total traversed rows: %d\n\n", rowsCount)
totalErrors += errCount
totalRows += rowsCount
deleteCount, errCount = performDeleteQueries(cluster, &info, columnSettings{UseBlob: true, UseSeq: true})
totalErrors += errCount
totalDeletes += deleteCount
}
// ledger_hashes queries
if !*skipLedgerHashesTable {
log.Println("Generating delete queries for ledger_hashes table")
info, rowsCount, errCount = prepareDeleteQueries(cluster, fromLedgerIdx,
"SELECT hash, sequence FROM ledger_hashes WHERE token(hash) >= ? AND token(hash) <= ?",
"DELETE FROM ledger_hashes WHERE hash = ?")
log.Printf("Total delete queries: %d\n", len(info.Data))
log.Printf("Total traversed rows: %d\n\n", rowsCount)
totalErrors += errCount
totalRows += rowsCount
deleteCount, errCount = performDeleteQueries(cluster, &info, columnSettings{UseBlob: true, UseSeq: false})
totalErrors += errCount
totalDeletes += deleteCount
}
// transactions queries
if !*skipTransactionsTable {
log.Println("Generating delete queries for transactions table")
info, rowsCount, errCount = prepareDeleteQueries(cluster, fromLedgerIdx,
"SELECT hash, ledger_sequence FROM transactions WHERE token(hash) >= ? AND token(hash) <= ?",
"DELETE FROM transactions WHERE hash = ?")
log.Printf("Total delete queries: %d\n", len(info.Data))
log.Printf("Total traversed rows: %d\n\n", rowsCount)
totalErrors += errCount
totalRows += rowsCount
deleteCount, errCount = performDeleteQueries(cluster, &info, columnSettings{UseBlob: true, UseSeq: false})
totalErrors += errCount
totalDeletes += deleteCount
}
// diff queries
if !*skipDiffTable {
log.Println("Generating delete queries for diff table")
info = prepareSimpleDeleteQueries(fromLedgerIdx, toLedgerIdx,
"DELETE FROM diff WHERE seq = ?")
log.Printf("Total delete queries: %d\n\n", len(info.Data))
deleteCount, errCount = performDeleteQueries(cluster, &info, columnSettings{UseBlob: true, UseSeq: true})
totalErrors += errCount
totalDeletes += deleteCount
}
// ledger_transactions queries
if !*skipLedgerTransactionsTable {
log.Println("Generating delete queries for ledger_transactions table")
info = prepareSimpleDeleteQueries(fromLedgerIdx, toLedgerIdx,
"DELETE FROM ledger_transactions WHERE ledger_sequence = ?")
log.Printf("Total delete queries: %d\n\n", len(info.Data))
deleteCount, errCount = performDeleteQueries(cluster, &info, columnSettings{UseBlob: false, UseSeq: true})
totalErrors += errCount
totalDeletes += deleteCount
}
// ledgers queries
if !*skipLedgersTable {
log.Println("Generating delete queries for ledgers table")
info = prepareSimpleDeleteQueries(fromLedgerIdx, toLedgerIdx,
"DELETE FROM ledgers WHERE sequence = ?")
log.Printf("Total delete queries: %d\n\n", len(info.Data))
deleteCount, errCount = performDeleteQueries(cluster, &info, columnSettings{UseBlob: false, UseSeq: true})
totalErrors += errCount
totalDeletes += deleteCount
}
// TODO: tbd what to do with account_tx as it got tuple for seq_idx
// TODO: also, whether we need to take care of nft tables and other stuff like that
if !*skipWriteLatestLedger {
if err := updateLedgerRange(cluster, fromLedgerIdx-1); err != nil {
log.Printf("ERROR failed updating ledger range: %s\n", err)
return err
}
log.Printf("Updated latest ledger to %d in ledger_range table\n\n", fromLedgerIdx-1)
}
log.Printf("TOTAL ERRORS: %d\n", totalErrors)
log.Printf("TOTAL ROWS TRAVERSED: %d\n", totalRows)
log.Printf("TOTAL DELETES: %d\n\n", totalDeletes)
log.Printf("Completed deletion for %d -> %d\n\n", fromLedgerIdx, toLedgerIdx)
return nil
}
func prepareSimpleDeleteQueries(fromLedgerIdx uint64, toLedgerIdx uint64, deleteQueryTemplate string) deleteInfo {
var info = deleteInfo{Query: deleteQueryTemplate}
// Note: we deliberately add 1 extra ledger to make sure we delete any data Clio might have written
// if it crashed or was stopped in the middle of writing just before it wrote ledger_range.
for i := fromLedgerIdx; i <= toLedgerIdx+1; i++ {
info.Data = append(info.Data, deleteParams{Seq: i})
}
return info
}
func prepareDeleteQueries(cluster *gocql.ClusterConfig, fromLedgerIdx uint64, queryTemplate string, deleteQueryTemplate string) (deleteInfo, uint64, uint64) {
rangesChannel := make(chan *tokenRange, len(ranges))
for i := range ranges {
rangesChannel <- ranges[i]
}
close(rangesChannel)
outChannel := make(chan deleteParams)
var info = deleteInfo{Query: deleteQueryTemplate}
go func() {
for params := range outChannel {
info.Data = append(info.Data, params)
}
}()
var wg sync.WaitGroup
var sessionCreationWaitGroup sync.WaitGroup
var totalRows uint64
var totalErrors uint64
wg.Add(workerCount)
sessionCreationWaitGroup.Add(workerCount)
for i := 0; i < workerCount; i++ {
go func(q string) {
defer wg.Done()
var session *gocql.Session
var err error
if session, err = cluster.CreateSession(); err == nil {
defer session.Close()
sessionCreationWaitGroup.Done()
sessionCreationWaitGroup.Wait()
preparedQuery := session.Query(q)
for r := range rangesChannel {
preparedQuery.Bind(r.StartRange, r.EndRange)
var pageState []byte
var rowsRetrieved uint64
for {
iter := preparedQuery.PageSize(*clusterPageSize).PageState(pageState).Iter()
nextPageState := iter.PageState()
scanner := iter.Scanner()
for scanner.Next() {
var key []byte
var seq uint64
err = scanner.Scan(&key, &seq)
if err == nil {
rowsRetrieved++
// only grab the rows that are in the correct range of sequence numbers
if fromLedgerIdx <= seq {
outChannel <- deleteParams{Seq: seq, Blob: key}
}
} else {
log.Printf("ERROR: page iteration failed: %s\n", err)
fmt.Fprintf(os.Stderr, "FAILED QUERY: %s\n", fmt.Sprintf("%s [from=%d][to=%d][pagestate=%x]", queryTemplate, r.StartRange, r.EndRange, pageState))
atomic.AddUint64(&totalErrors, 1)
}
}
if len(nextPageState) == 0 {
break
}
pageState = nextPageState
}
atomic.AddUint64(&totalRows, rowsRetrieved)
}
} else {
log.Printf("ERROR: %s\n", err)
fmt.Fprintf(os.Stderr, "FAILED TO CREATE SESSION: %s\n", err)
atomic.AddUint64(&totalErrors, 1)
}
}(queryTemplate)
}
wg.Wait()
close(outChannel)
return info, totalRows, totalErrors
}
func performDeleteQueries(cluster *gocql.ClusterConfig, info *deleteInfo, colSettings columnSettings) (uint64, uint64) {
var wg sync.WaitGroup
var sessionCreationWaitGroup sync.WaitGroup
var totalDeletes uint64
var totalErrors uint64
chunks := splitDeleteWork(info)
chunksChannel := make(chan []deleteParams, len(chunks))
for i := range chunks {
chunksChannel <- chunks[i]
}
close(chunksChannel)
wg.Add(workerCount)
sessionCreationWaitGroup.Add(workerCount)
query := info.Query
bindCount := strings.Count(query, "?")
for i := 0; i < workerCount; i++ {
go func(number int, q string, bc int) {
defer wg.Done()
var session *gocql.Session
var err error
if session, err = cluster.CreateSession(); err == nil {
defer session.Close()
sessionCreationWaitGroup.Done()
sessionCreationWaitGroup.Wait()
preparedQuery := session.Query(q)
for chunk := range chunksChannel {
for _, r := range chunk {
if bc == 2 {
preparedQuery.Bind(r.Blob, r.Seq)
} else if bc == 1 {
if colSettings.UseSeq {
preparedQuery.Bind(r.Seq)
} else if colSettings.UseBlob {
preparedQuery.Bind(r.Blob)
}
}
if err := preparedQuery.Exec(); err != nil {
log.Printf("DELETE ERROR: %s\n", err)
fmt.Fprintf(os.Stderr, "FAILED QUERY: %s\n", fmt.Sprintf("%s [blob=0x%x][seq=%d]", info.Query, r.Blob, r.Seq))
atomic.AddUint64(&totalErrors, 1)
} else {
atomic.AddUint64(&totalDeletes, 1)
}
}
}
} else {
log.Printf("ERROR: %s\n", err)
fmt.Fprintf(os.Stderr, "FAILED TO CREATE SESSION: %s\n", err)
atomic.AddUint64(&totalErrors, 1)
}
}(i, query, bindCount)
}
wg.Wait()
return totalDeletes, totalErrors
}
func updateLedgerRange(cluster *gocql.ClusterConfig, ledgerIndex uint64) error {
log.Printf("Updating latest ledger to %d\n", ledgerIndex)
if session, err := cluster.CreateSession(); err == nil {
defer session.Close()
query := "UPDATE ledger_range SET sequence = ? WHERE is_latest = ?"
preparedQuery := session.Query(query, ledgerIndex, true)
if err := preparedQuery.Exec(); err != nil {
fmt.Fprintf(os.Stderr, "FAILED QUERY: %s [seq=%d][true]\n", query, ledgerIndex)
return err
}
} else {
fmt.Fprintf(os.Stderr, "FAILED TO CREATE SESSION: %s\n", err)
return err
}
return nil
}