refactor: change the return type of mulDiv to std::optional (#4243)

- Previously, mulDiv had `std::pair<bool, uint64_t>` as the output type.
  - This is an error-prone interface as it is easy to ignore when
    overflow occurs.
- Using a return type of `std::optional` should decrease the likelihood
  of ignoring overflow.
  - It also allows for the use of optional::value_or() as a way to
    explicitly recover from overflow.
- Include limits.h header file preprocessing directive in order to
  satisfy gcc's numeric_limits incomplete_type requirement.

Fix #3495

---------

Co-authored-by: John Freeman <jfreeman08@gmail.com>
This commit is contained in:
Chenna Keshava B S
2023-07-05 20:11:19 -07:00
committed by GitHub
parent 77dc63b549
commit c6fee28b92
12 changed files with 93 additions and 83 deletions

View File

@@ -409,7 +409,7 @@ template <
class Source2,
class Dest,
class = enable_muldiv_t<Source1, Source2, Dest>>
std::pair<bool, Dest>
std::optional<Dest>
mulDivU(Source1 value, Dest mul, Source2 div)
{
// Fees can never be negative in any context.
@@ -420,7 +420,7 @@ mulDivU(Source1 value, Dest mul, Source2 div)
assert(value.value() >= 0);
assert(mul.value() >= 0);
assert(div.value() >= 0);
return {false, Dest{0}};
return std::nullopt;
}
using desttype = typename Dest::value_type;
@@ -428,12 +428,12 @@ mulDivU(Source1 value, Dest mul, Source2 div)
// Shortcuts, since these happen a lot in the real world
if (value == div)
return {true, mul};
return mul;
if (mul.value() == div.value())
{
if (value.value() > max)
return {false, Dest{max}};
return {true, Dest{static_cast<desttype>(value.value())}};
return std::nullopt;
return Dest{static_cast<desttype>(value.value())};
}
using namespace boost::multiprecision;
@@ -447,9 +447,9 @@ mulDivU(Source1 value, Dest mul, Source2 div)
auto quotient = product / div.value();
if (quotient > max)
return {false, Dest{max}};
return std::nullopt;
return {true, Dest{static_cast<desttype>(quotient)}};
return Dest{static_cast<desttype>(quotient)};
}
} // namespace feeunit
@@ -464,7 +464,7 @@ template <
class Source2,
class Dest,
class = feeunit::enable_muldiv_t<Source1, Source2, Dest>>
std::pair<bool, Dest>
std::optional<Dest>
mulDiv(Source1 value, Dest mul, Source2 div)
{
return feeunit::mulDivU(value, mul, div);
@@ -475,7 +475,7 @@ template <
class Source2,
class Dest,
class = feeunit::enable_muldiv_commute_t<Source1, Source2, Dest>>
std::pair<bool, Dest>
std::optional<Dest>
mulDiv(Dest value, Source1 mul, Source2 div)
{
// Multiplication is commutative
@@ -483,7 +483,7 @@ mulDiv(Dest value, Source1 mul, Source2 div)
}
template <class Dest, class = feeunit::enable_muldiv_dest_t<Dest>>
std::pair<bool, Dest>
std::optional<Dest>
mulDiv(std::uint64_t value, Dest mul, std::uint64_t div)
{
// Give the scalars a non-tag so the
@@ -492,7 +492,7 @@ mulDiv(std::uint64_t value, Dest mul, std::uint64_t div)
}
template <class Dest, class = feeunit::enable_muldiv_dest_t<Dest>>
std::pair<bool, Dest>
std::optional<Dest>
mulDiv(Dest value, std::uint64_t mul, std::uint64_t div)
{
// Multiplication is commutative
@@ -503,20 +503,24 @@ template <
class Source1,
class Source2,
class = feeunit::enable_muldiv_sources_t<Source1, Source2>>
std::pair<bool, std::uint64_t>
std::optional<std::uint64_t>
mulDiv(Source1 value, std::uint64_t mul, Source2 div)
{
// Give the scalars a dimensionless unit so the
// unit-handling version gets called.
auto unitresult = feeunit::mulDivU(value, feeunit::scalar(mul), div);
return {unitresult.first, unitresult.second.value()};
if (!unitresult)
return std::nullopt;
return unitresult->value();
}
template <
class Source1,
class Source2,
class = feeunit::enable_muldiv_sources_t<Source1, Source2>>
std::pair<bool, std::uint64_t>
std::optional<std::uint64_t>
mulDiv(std::uint64_t value, Source1 mul, Source2 div)
{
// Multiplication is commutative

View File

@@ -17,15 +17,13 @@
*/
//==============================================================================
#include <ripple/basics/contract.h>
#include <ripple/basics/mulDiv.h>
#include <boost/multiprecision/cpp_int.hpp>
#include <limits>
#include <utility>
#include <optional>
namespace ripple {
std::pair<bool, std::uint64_t>
std::optional<std::uint64_t>
mulDiv(std::uint64_t value, std::uint64_t mul, std::uint64_t div)
{
using namespace boost::multiprecision;
@@ -35,12 +33,10 @@ mulDiv(std::uint64_t value, std::uint64_t mul, std::uint64_t div)
result /= div;
auto constexpr limit = std::numeric_limits<std::uint64_t>::max();
if (result > ripple::muldiv_max)
return std::nullopt;
if (result > limit)
return {false, limit};
return {true, static_cast<std::uint64_t>(result)};
return static_cast<std::uint64_t>(result);
}
} // namespace ripple

View File

@@ -21,9 +21,12 @@
#define RIPPLE_BASICS_MULDIV_H_INCLUDED
#include <cstdint>
#include <limits>
#include <optional>
#include <utility>
namespace ripple {
auto constexpr muldiv_max = std::numeric_limits<std::uint64_t>::max();
/** Return value*mul/div accurately.
Computes the result of the multiplication and division in
@@ -31,14 +34,11 @@ namespace ripple {
Throws:
None
Returns:
`std::pair`:
`first` is `false` if the calculation overflows,
`true` if the calculation is safe.
`second` is the result of the calculation if
`first` is `false`, max value of `uint64_t`
if `true`.
`std::optional`:
`std::nullopt` if the calculation overflows. Otherwise, `value * mul
/ div`.
*/
std::pair<bool, std::uint64_t>
std::optional<std::uint64_t>
mulDiv(std::uint64_t value, std::uint64_t mul, std::uint64_t div);
} // namespace ripple