Add safe_cast (RIPD-1702):

This change ensures that no overflow can occur when casting
between enums and integral types.
This commit is contained in:
Howard Hinnant
2018-12-21 17:13:58 -05:00
committed by Nik Bougalis
parent 494724578a
commit 148bbf4e8f
35 changed files with 213 additions and 86 deletions

View File

@@ -20,6 +20,7 @@
#ifndef RIPPLE_PROTOCOL_SFIELD_H_INCLUDED
#define RIPPLE_PROTOCOL_SFIELD_H_INCLUDED
#include <ripple/basics/safe_cast.h>
#include <ripple/json/json_value.h>
#include <cstdint>
#include <utility>
@@ -86,7 +87,7 @@ inline
int
field_code(SerializedTypeID id, int index)
{
return (static_cast<int>(id) << 16) | index;
return (safe_cast<int>(id) << 16) | index;
}
// constexpr

View File

@@ -20,12 +20,13 @@
#ifndef RIPPLE_PROTOCOL_SERIALIZER_H_INCLUDED
#define RIPPLE_PROTOCOL_SERIALIZER_H_INCLUDED
#include <ripple/protocol/SField.h>
#include <ripple/basics/base_uint.h>
#include <ripple/basics/contract.h>
#include <ripple/basics/Buffer.h>
#include <ripple/basics/safe_cast.h>
#include <ripple/basics/Slice.h>
#include <ripple/beast/crypto/secure_erase.h>
#include <ripple/protocol/SField.h>
#include <cassert>
#include <cstdint>
#include <iomanip>
@@ -158,7 +159,7 @@ public:
int addFieldID (int type, int name);
int addFieldID (SerializedTypeID type, int name)
{
return addFieldID (static_cast<int> (type), name);
return addFieldID (safe_cast<int> (type), name);
}
// DEPRECATED
@@ -266,7 +267,7 @@ public:
std::setw (2) <<
std::hex <<
std::setfill ('0') <<
static_cast<unsigned int>(element);
safe_cast<unsigned int>(element);
}
return h.str ();
}

View File

@@ -20,6 +20,7 @@
#ifndef RIPPLE_PROTOCOL_TER_H_INCLUDED
#define RIPPLE_PROTOCOL_TER_H_INCLUDED
#include <ripple/basics/safe_cast.h>
#include <ripple/json/json_value.h>
#include <boost/optional.hpp>
@@ -270,22 +271,22 @@ enum TECcodes : TERUnderlyingType
// For generic purposes, a free function that returns the value of a TE*codes.
constexpr TERUnderlyingType TERtoInt (TELcodes v)
{ return static_cast<TERUnderlyingType>(v); }
{ return safe_cast<TERUnderlyingType>(v); }
constexpr TERUnderlyingType TERtoInt (TEMcodes v)
{ return static_cast<TERUnderlyingType>(v); }
{ return safe_cast<TERUnderlyingType>(v); }
constexpr TERUnderlyingType TERtoInt (TEFcodes v)
{ return static_cast<TERUnderlyingType>(v); }
{ return safe_cast<TERUnderlyingType>(v); }
constexpr TERUnderlyingType TERtoInt (TERcodes v)
{ return static_cast<TERUnderlyingType>(v); }
{ return safe_cast<TERUnderlyingType>(v); }
constexpr TERUnderlyingType TERtoInt (TEScodes v)
{ return static_cast<TERUnderlyingType>(v); }
{ return safe_cast<TERUnderlyingType>(v); }
constexpr TERUnderlyingType TERtoInt (TECcodes v)
{ return static_cast<TERUnderlyingType>(v); }
{ return safe_cast<TERUnderlyingType>(v); }
//------------------------------------------------------------------------------
// Template class that is specific to selected ranges of error codes. The

View File

@@ -17,6 +17,7 @@
*/
//==============================================================================
#include <ripple/basics/safe_cast.h>
#include <ripple/protocol/SField.h>
#include <cassert>
#include <map>
@@ -358,7 +359,7 @@ SField::getField (int code)
if (it != unknownCodeToField.end ())
return * (it->second);
return *(unknownCodeToField[code] = std::unique_ptr<SField const>(
new SField(static_cast<SerializedTypeID>(type), field)));
new SField(safe_cast<SerializedTypeID>(type), field)));
}
}
@@ -385,7 +386,7 @@ std::string SField::getName () const
if (fieldValue == 0)
return "";
return std::to_string(static_cast<int> (fieldType)) + "/" +
return std::to_string(safe_cast<int> (fieldType)) + "/" +
std::to_string(fieldValue);
}

View File

@@ -20,6 +20,7 @@
#include <ripple/basics/contract.h>
#include <ripple/basics/Log.h>
#include <ripple/basics/safe_cast.h>
#include <ripple/protocol/JsonFields.h>
#include <ripple/protocol/SystemParameters.h>
#include <ripple/protocol/STAmount.h>
@@ -246,13 +247,13 @@ STAmount::STAmount (Issue const& issue,
STAmount::STAmount (Issue const& issue,
std::uint32_t mantissa, int exponent, bool negative)
: STAmount (issue, static_cast<std::uint64_t>(mantissa), exponent, negative)
: STAmount (issue, safe_cast<std::uint64_t>(mantissa), exponent, negative)
{
}
STAmount::STAmount (Issue const& issue,
int mantissa, int exponent)
: STAmount (issue, static_cast<std::int64_t>(mantissa), exponent)
: STAmount (issue, safe_cast<std::int64_t>(mantissa), exponent)
{
}

View File

@@ -19,6 +19,7 @@
#include <ripple/basics/Log.h>
#include <ripple/basics/StringUtilities.h>
#include <ripple/basics/safe_cast.h>
#include <ripple/protocol/LedgerFormats.h>
#include <ripple/protocol/STInteger.h>
#include <ripple/protocol/TxFormats.h>
@@ -98,7 +99,7 @@ STUInt16::getText () const
if (getFName () == sfLedgerEntryType)
{
auto item = LedgerFormats::getInstance ().findByType (
static_cast <LedgerEntryType> (value_));
safe_cast<LedgerEntryType> (value_));
if (item != nullptr)
return item->getName ();
@@ -107,7 +108,7 @@ STUInt16::getText () const
if (getFName () == sfTransactionType)
{
auto item =TxFormats::getInstance().findByType (
static_cast <TxType> (value_));
safe_cast<TxType> (value_));
if (item != nullptr)
return item->getName ();
@@ -123,7 +124,7 @@ STUInt16::getJson (int) const
if (getFName () == sfLedgerEntryType)
{
auto item = LedgerFormats::getInstance ().findByType (
static_cast <LedgerEntryType> (value_));
safe_cast<LedgerEntryType> (value_));
if (item != nullptr)
return item->getName ();
@@ -132,7 +133,7 @@ STUInt16::getJson (int) const
if (getFName () == sfTransactionType)
{
auto item = TxFormats::getInstance().findByType (
static_cast <TxType> (value_));
safe_cast<TxType> (value_));
if (item != nullptr)
return item->getName ();

View File

@@ -19,6 +19,7 @@
#include <ripple/basics/contract.h>
#include <ripple/basics/Log.h>
#include <ripple/basics/safe_cast.h>
#include <ripple/json/to_string.h>
#include <ripple/protocol/Indexes.h>
#include <ripple/protocol/JsonFields.h>
@@ -32,6 +33,10 @@ STLedgerEntry::STLedgerEntry (Keylet const& k)
, key_ (k.key)
, type_ (k.type)
{
if (!(0u <= type_ &&
type_ <= std::min<unsigned>(std::numeric_limits<std::uint16_t>::max(),
std::numeric_limits<std::underlying_type_t<LedgerEntryType>>::max())))
Throw<std::runtime_error> ("invalid ledger entry type: out of range");
auto const format =
LedgerFormats::getInstance().findByType (type_);
@@ -66,7 +71,7 @@ STLedgerEntry::STLedgerEntry (
void STLedgerEntry::setSLEType ()
{
auto format = LedgerFormats::getInstance().findByType (
static_cast <LedgerEntryType> (
safe_cast <LedgerEntryType> (
getFieldU16 (sfLedgerEntryType)));
if (format == nullptr)

View File

@@ -18,6 +18,7 @@
//==============================================================================
#include <ripple/basics/contract.h>
#include <ripple/basics/safe_cast.h>
#include <ripple/basics/StringUtilities.h>
#include <ripple/protocol/ErrorCodes.h>
#include <ripple/protocol/LedgerFormats.h>
@@ -275,6 +276,9 @@ static boost::optional<detail::STVar> parseLeaf (
TxType const txType (TxFormats::getInstance().
findTypeByName (strValue));
if (txType == ttINVALID)
Throw<std::runtime_error>(
"Invalid transaction format name");
ret = detail::make_stvar <STUInt16> (field,
static_cast <std::uint16_t> (txType));
@@ -287,6 +291,13 @@ static boost::optional<detail::STVar> parseLeaf (
LedgerFormats::getInstance().
findTypeByName (strValue));
if (!(0u <= type &&
type <= std::min<unsigned>(
std::numeric_limits<std::uint16_t>::max(),
std::numeric_limits<std::underlying_type_t
<LedgerEntryType>>::max())))
Throw<std::runtime_error>(
"Invalid ledger entry type: out of range");
ret = detail::make_stvar <STUInt16> (field,
static_cast <std::uint16_t> (type));
@@ -346,7 +357,7 @@ static boost::optional<detail::STVar> parseLeaf (
else if (value.isUInt ())
{
ret = detail::make_stvar <STUInt32> (field,
static_cast <std::uint32_t> (value.asUInt ()));
safe_cast <std::uint32_t> (value.asUInt ()));
}
else
{
@@ -378,7 +389,7 @@ static boost::optional<detail::STVar> parseLeaf (
else if (value.isUInt ())
{
ret = detail::make_stvar <STUInt64> (field,
static_cast <std::uint64_t> (value.asUInt ()));
safe_cast <std::uint64_t> (value.asUInt ()));
}
else
{

View File

@@ -17,6 +17,10 @@
*/
//==============================================================================
#include <ripple/basics/contract.h>
#include <ripple/basics/Log.h>
#include <ripple/basics/safe_cast.h>
#include <ripple/basics/StringUtilities.h>
#include <ripple/protocol/STTx.h>
#include <ripple/protocol/HashPrefix.h>
#include <ripple/protocol/JsonFields.h>
@@ -27,9 +31,6 @@
#include <ripple/protocol/STArray.h>
#include <ripple/protocol/TxFlags.h>
#include <ripple/protocol/UintTypes.h>
#include <ripple/basics/contract.h>
#include <ripple/basics/Log.h>
#include <ripple/basics/StringUtilities.h>
#include <ripple/json/to_string.h>
#include <boost/format.hpp>
#include <array>
@@ -49,7 +50,7 @@ auto getTxFormat (TxType type)
Throw<std::runtime_error> (
"Invalid transaction type " +
std::to_string (
static_cast<std::underlying_type_t<TxType>>(type)));
safe_cast<std::underlying_type_t<TxType>>(type)));
}
return format;
@@ -58,7 +59,7 @@ auto getTxFormat (TxType type)
STTx::STTx (STObject&& object) noexcept (false)
: STObject (std::move (object))
{
tx_type_ = static_cast <TxType> (getFieldU16 (sfTransactionType));
tx_type_ = safe_cast<TxType> (getFieldU16 (sfTransactionType));
applyTemplate (getTxFormat (tx_type_)->elements); // may throw
tid_ = getHash(HashPrefix::transactionID);
}
@@ -74,7 +75,7 @@ STTx::STTx (SerialIter& sit) noexcept (false)
if (set (sit))
Throw<std::runtime_error> ("Transaction contains an object terminator");
tx_type_ = static_cast<TxType> (getFieldU16 (sfTransactionType));
tx_type_ = safe_cast<TxType> (getFieldU16 (sfTransactionType));
applyTemplate (getTxFormat (tx_type_)->elements); // May throw
tid_ = getHash(HashPrefix::transactionID);
@@ -92,7 +93,7 @@ STTx::STTx (
assembler (*this);
tx_type_ = static_cast<TxType>(getFieldU16 (sfTransactionType));
tx_type_ = safe_cast<TxType>(getFieldU16 (sfTransactionType));
if (tx_type_ != type)
LogicError ("Transaction type was mutated during assembly");
@@ -523,7 +524,7 @@ isPseudoTx(STObject const& tx)
auto t = tx[~sfTransactionType];
if (!t)
return false;
auto tt = static_cast<TxType>(*t);
auto tt = safe_cast<TxType>(*t);
return tt == ttAMENDMENT || tt == ttFEE;
}

View File

@@ -17,6 +17,7 @@
*/
//==============================================================================
#include <ripple/basics/safe_cast.h>
#include <ripple/protocol/tokens.h>
#include <ripple/protocol/digest.h>
#include <boost/container/small_vector.hpp>
@@ -166,7 +167,7 @@ encodeToken (TokenType type,
// Lay the data out as
// <type><token><checksum>
buf[0] = static_cast<std::underlying_type_t <TokenType>>(type);
buf[0] = safe_cast<std::underlying_type_t <TokenType>>(type);
if (size)
std::memcpy(buf.data() + 1, token, size);
checksum(buf.data() + 1 + size, buf.data(), 1 + size);
@@ -259,14 +260,14 @@ std::string
decodeBase58Token (std::string const& s,
TokenType type, InverseArray const& inv)
{
auto ret = decodeBase58(s, inv);
std::string const ret = decodeBase58(s, inv);
// Reject zero length tokens
if (ret.size() < 6)
return {};
// The type must match.
if (type != static_cast<TokenType>(ret[0]))
if (type != safe_cast<TokenType>(static_cast<std::uint8_t>(ret[0])))
return {};
// And the checksum must as well.