Some missing lock logic.

This commit is contained in:
JoelKatz
2012-01-05 16:53:31 -08:00
parent a532bfb402
commit c1213b20c6
2 changed files with 51 additions and 35 deletions

View File

@@ -285,8 +285,9 @@ LocalAccountEntry::pointer LocalAccountFamily::get(int seq)
uint160 Wallet::findFamilySN(const std::string& shortName) uint160 Wallet::findFamilySN(const std::string& shortName)
{ // OPTIMIZEME { // OPTIMIZEME
for(std::map<uint160, LocalAccountFamily::pointer>::iterator it=families.begin(); boost::recursive_mutex::scoped_lock sl(mLock);
it!=families.end(); ++it) for(std::map<uint160, LocalAccountFamily::pointer>::iterator it=mFamilies.begin();
it!=mFamilies.end(); ++it)
{ {
if(it->second->getShortName()==shortName) if(it->second->getShortName()==shortName)
return it->first; return it->first;
@@ -382,15 +383,17 @@ CKey::pointer LocalAccount::getPrivateKey()
void Wallet::getFamilies(std::vector<uint160>& familyIDs) void Wallet::getFamilies(std::vector<uint160>& familyIDs)
{ {
familyIDs.reserve(families.size()); boost::recursive_mutex::scoped_lock sl(mLock);
for(std::map<uint160, LocalAccountFamily::pointer>::iterator fit=families.begin(); fit!=families.end(); ++fit) familyIDs.reserve(mFamilies.size());
for(std::map<uint160, LocalAccountFamily::pointer>::iterator fit=mFamilies.begin(); fit!=mFamilies.end(); ++fit)
familyIDs.push_back(fit->first); familyIDs.push_back(fit->first);
} }
bool Wallet::getFamilyInfo(const uint160& family, std::string& name, std::string& comment) bool Wallet::getFamilyInfo(const uint160& family, std::string& name, std::string& comment)
{ {
std::map<uint160, LocalAccountFamily::pointer>::iterator fit=families.find(family); boost::recursive_mutex::scoped_lock sl(mLock);
if(fit==families.end()) return false; std::map<uint160, LocalAccountFamily::pointer>::iterator fit=mFamilies.find(family);
if(fit==mFamilies.end()) return false;
assert(fit->second->getFamily()==family); assert(fit->second->getFamily()==family);
name=fit->second->getShortName(); name=fit->second->getShortName();
comment=fit->second->getComment(); comment=fit->second->getComment();
@@ -400,8 +403,9 @@ bool Wallet::getFamilyInfo(const uint160& family, std::string& name, std::string
bool Wallet::getFullFamilyInfo(const uint160& family, std::string& name, std::string& comment, bool Wallet::getFullFamilyInfo(const uint160& family, std::string& name, std::string& comment,
std::string& pubGen, bool& isLocked) std::string& pubGen, bool& isLocked)
{ {
std::map<uint160, LocalAccountFamily::pointer>::iterator fit=families.find(family); boost::recursive_mutex::scoped_lock sl(mLock);
if(fit==families.end()) return false; std::map<uint160, LocalAccountFamily::pointer>::iterator fit=mFamilies.find(family);
if(fit==mFamilies.end()) return false;
assert(fit->second->getFamily()==family); assert(fit->second->getFamily()==family);
name=fit->second->getShortName(); name=fit->second->getShortName();
comment=fit->second->getComment(); comment=fit->second->getComment();
@@ -451,38 +455,42 @@ void Wallet::load()
std::string Wallet::getPubGenHex(const uint160& famBase) std::string Wallet::getPubGenHex(const uint160& famBase)
{ {
std::map<uint160, LocalAccountFamily::pointer>::iterator fit=families.find(famBase); boost::recursive_mutex::scoped_lock sl(mLock);
if(fit==families.end()) return ""; std::map<uint160, LocalAccountFamily::pointer>::iterator fit=mFamilies.find(famBase);
if(fit==mFamilies.end()) return "";
assert(fit->second->getFamily()==famBase); assert(fit->second->getFamily()==famBase);
return fit->second->getPubGenHex(); return fit->second->getPubGenHex();
} }
std::string Wallet::getShortName(const uint160& famBase) std::string Wallet::getShortName(const uint160& famBase)
{ {
std::map<uint160, LocalAccountFamily::pointer>::iterator fit=families.find(famBase); boost::recursive_mutex::scoped_lock sl(mLock);
if(fit==families.end()) return ""; std::map<uint160, LocalAccountFamily::pointer>::iterator fit=mFamilies.find(famBase);
if(fit==mFamilies.end()) return "";
assert(fit->second->getFamily()==famBase); assert(fit->second->getFamily()==famBase);
return fit->second->getShortName(); return fit->second->getShortName();
} }
LocalAccount::pointer Wallet::getLocalAccount(const uint160& family, int seq) LocalAccount::pointer Wallet::getLocalAccount(const uint160& family, int seq)
{ {
std::map<uint160, LocalAccountFamily::pointer>::iterator fit=families.find(family); boost::recursive_mutex::scoped_lock sl(mLock);
if(fit==families.end()) return LocalAccount::pointer(); std::map<uint160, LocalAccountFamily::pointer>::iterator fit=mFamilies.find(family);
if(fit==mFamilies.end()) return LocalAccount::pointer();
uint160 acct=fit->second->getAccount(seq, true); uint160 acct=fit->second->getAccount(seq, true);
std::map<uint160, LocalAccount::pointer>::iterator ait=accounts.find(acct); std::map<uint160, LocalAccount::pointer>::iterator ait=mAccounts.find(acct);
if(ait!=accounts.end()) return ait->second; if(ait!=mAccounts.end()) return ait->second;
LocalAccount::pointer lac(new LocalAccount(fit->second, seq)); LocalAccount::pointer lac(new LocalAccount(fit->second, seq));
accounts.insert(std::make_pair(acct, lac)); mAccounts.insert(std::make_pair(acct, lac));
return lac; return lac;
} }
LocalAccount::pointer Wallet::getLocalAccount(const uint160& acctID) LocalAccount::pointer Wallet::getLocalAccount(const uint160& acctID)
{ {
std::map<uint160, LocalAccount::pointer>::iterator ait=accounts.find(acctID); boost::recursive_mutex::scoped_lock sl(mLock);
if(ait==accounts.end()) return LocalAccount::pointer(); std::map<uint160, LocalAccount::pointer>::iterator ait=mAccounts.find(acctID);
if(ait==mAccounts.end()) return LocalAccount::pointer();
return ait->second; return ait->second;
} }
@@ -521,21 +529,23 @@ LocalAccount::pointer Wallet::parseAccount(const std::string& specifier)
uint160 Wallet::peekKey(const uint160& family, int seq) uint160 Wallet::peekKey(const uint160& family, int seq)
{ {
std::map<uint160, LocalAccountFamily::pointer>::iterator fit=families.find(family); boost::recursive_mutex::scoped_lock sl(mLock);
if(fit==families.end()) return uint160(); std::map<uint160, LocalAccountFamily::pointer>::iterator fit=mFamilies.find(family);
if(fit==mFamilies.end()) return uint160();
return fit->second->getAccount(seq, false); return fit->second->getAccount(seq, false);
} }
void Wallet::delFamily(const uint160& familyName) void Wallet::delFamily(const uint160& familyName)
{ {
std::map<uint160, LocalAccountFamily::pointer>::iterator fit=families.find(familyName); boost::recursive_mutex::scoped_lock sl(mLock);
if(fit==families.end()) return; std::map<uint160, LocalAccountFamily::pointer>::iterator fit=mFamilies.find(familyName);
if(fit==mFamilies.end()) return;
std::map<int, LocalAccountEntry::pointer>& acctMap=fit->second->getAcctMap(); std::map<int, LocalAccountEntry::pointer>& acctMap=fit->second->getAcctMap();
for(std::map<int, LocalAccountEntry::pointer>::iterator it=acctMap.begin(); it!=acctMap.end(); ++it) for(std::map<int, LocalAccountEntry::pointer>::iterator it=acctMap.begin(); it!=acctMap.end(); ++it)
accounts.erase(it->second->getAccountID()); mAccounts.erase(it->second->getAccountID());
families.erase(familyName); mFamilies.erase(familyName);
} }
LocalAccountFamily::pointer Wallet::doPublic(const std::string& pubKey, bool do_create, bool do_db) LocalAccountFamily::pointer Wallet::doPublic(const std::string& pubKey, bool do_create, bool do_db)
@@ -550,8 +560,9 @@ LocalAccountFamily::pointer Wallet::doPublic(const std::string& pubKey, bool do_
while(rootPubKey.size()<33) rootPubKey.push_back((unsigned char)0); while(rootPubKey.size()<33) rootPubKey.push_back((unsigned char)0);
uint160 family=NewcoinAddress(rootPubKey).GetHash160(); uint160 family=NewcoinAddress(rootPubKey).GetHash160();
std::map<uint160, LocalAccountFamily::pointer>::iterator fit=families.find(family); boost::recursive_mutex::scoped_lock sl(mLock);
if(fit!=families.end()) // already added std::map<uint160, LocalAccountFamily::pointer>::iterator fit=mFamilies.find(family);
if(fit!=mFamilies.end()) // already added
{ {
EC_KEY_free(pkey); EC_KEY_free(pkey);
return fit->second; return fit->second;
@@ -572,11 +583,12 @@ LocalAccountFamily::pointer Wallet::doPublic(const std::string& pubKey, bool do_
{ {
fam=LocalAccountFamily::pointer(new LocalAccountFamily(family, fam=LocalAccountFamily::pointer(new LocalAccountFamily(family,
EC_KEY_get0_group(pkey), EC_KEY_get0_public_key(pkey))); EC_KEY_get0_group(pkey), EC_KEY_get0_public_key(pkey)));
families.insert(std::make_pair(family, fam)); mFamilies.insert(std::make_pair(family, fam));
if(do_db) fam->write(true); if(do_db) fam->write(true);
} }
sl.unlock();
EC_KEY_free(pkey); EC_KEY_free(pkey);
return fam; return fam;
} }
@@ -592,9 +604,10 @@ LocalAccountFamily::pointer Wallet::doPrivate(const uint256& passPhrase, bool do
while(rootPubKey.size()<33) rootPubKey.push_back((unsigned char)0); while(rootPubKey.size()<33) rootPubKey.push_back((unsigned char)0);
uint160 family=NewcoinAddress(rootPubKey).GetHash160(); uint160 family=NewcoinAddress(rootPubKey).GetHash160();
boost::recursive_mutex::scoped_lock sl(mLock);
LocalAccountFamily::pointer fam; LocalAccountFamily::pointer fam;
std::map<uint160, LocalAccountFamily::pointer>::iterator it=families.find(family); std::map<uint160, LocalAccountFamily::pointer>::iterator it=mFamilies.find(family);
if(it==families.end()) if(it==mFamilies.end())
{ // family not found { // family not found
fam=LocalAccountFamily::readFamily(family); fam=LocalAccountFamily::readFamily(family);
if(!fam) if(!fam)
@@ -606,7 +619,7 @@ LocalAccountFamily::pointer Wallet::doPrivate(const uint256& passPhrase, bool do
} }
fam=LocalAccountFamily::pointer(new LocalAccountFamily(family, fam=LocalAccountFamily::pointer(new LocalAccountFamily(family,
EC_KEY_get0_group(base), EC_KEY_get0_public_key(base))); EC_KEY_get0_group(base), EC_KEY_get0_public_key(base)));
families.insert(std::make_pair(family, fam)); mFamilies.insert(std::make_pair(family, fam));
fam->write(true); fam->write(true);
} }
} }
@@ -615,12 +628,13 @@ LocalAccountFamily::pointer Wallet::doPrivate(const uint256& passPhrase, bool do
if(do_unlock && fam->isLocked()) if(do_unlock && fam->isLocked())
fam->unlock(EC_KEY_get0_private_key(base)); fam->unlock(EC_KEY_get0_private_key(base));
sl.unlock();
EC_KEY_free(base); EC_KEY_free(base);
return fam; return fam;
} }
bool Wallet::unitTest() bool Wallet::unitTest()
{ // Create 100 keys for each of 1,000 families Ensure all keys match { // Create 100 keys for each of 1,000 families and ensure all keys match
Wallet pub, priv; Wallet pub, priv;
uint256 privBase(time(NULL)^(getpid()<<16)); uint256 privBase(time(NULL)^(getpid()<<16));

View File

@@ -5,6 +5,7 @@
#include <map> #include <map>
#include <string> #include <string>
#include <boost/thread/recursive_mutex.hpp>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
#include "openssl/ec.h" #include "openssl/ec.h"
@@ -137,8 +138,9 @@ public:
class Wallet class Wallet
{ {
protected: protected:
std::map<uint160, LocalAccountFamily::pointer> families; boost::recursive_mutex mLock;
std::map<uint160, LocalAccount::pointer> accounts; std::map<uint160, LocalAccountFamily::pointer> mFamilies;
std::map<uint160, LocalAccount::pointer> mAccounts;
LocalAccountFamily::pointer doPrivate(const uint256& passPhrase, bool do_create, bool do_unlock); LocalAccountFamily::pointer doPrivate(const uint256& passPhrase, bool do_create, bool do_unlock);
LocalAccountFamily::pointer doPublic(const std::string& pubKey, bool do_create, bool do_db); LocalAccountFamily::pointer doPublic(const std::string& pubKey, bool do_create, bool do_db);