mirror of
https://github.com/XRPLF/clio.git
synced 2025-11-20 11:45:53 +00:00
fix: Data race in new webserver (#1926)
There was a data race inside `CoroutineGroup` because internal timer was used from multiple threads in the methods `asyncWait()` and `onCoroutineComplete()`. Changing `registerForeign()` to spawn to the same `yield_context` fixes the problem because now the timer is accessed only from the same coroutine which has an internal strand. During debugging I also added websocket support for `request_gun` tool.
This commit is contained in:
@@ -56,13 +56,15 @@ CoroutineGroup::spawn(boost::asio::yield_context yield, std::function<void(boost
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::optional<std::function<void()>>
|
std::optional<std::function<void()>>
|
||||||
CoroutineGroup::registerForeign()
|
CoroutineGroup::registerForeign(boost::asio::yield_context yield)
|
||||||
{
|
{
|
||||||
if (isFull())
|
if (isFull())
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
|
|
||||||
++childrenCounter_;
|
++childrenCounter_;
|
||||||
return [this]() { onCoroutineCompleted(); };
|
// It is important to spawn onCoroutineCompleted() to the same coroutine as will be calling asyncWait().
|
||||||
|
// timer_ here is not thread safe, so without spawn there could be a data race.
|
||||||
|
return [this, yield]() { boost::asio::spawn(yield, [this](auto&&) { onCoroutineCompleted(); }); };
|
||||||
}
|
}
|
||||||
|
|
||||||
void
|
void
|
||||||
|
|||||||
@@ -73,10 +73,11 @@ public:
|
|||||||
* @note A foreign coroutine is still counted as a child one, i.e. calling this method increases the size of the
|
* @note A foreign coroutine is still counted as a child one, i.e. calling this method increases the size of the
|
||||||
* group.
|
* group.
|
||||||
*
|
*
|
||||||
|
* @param yield The yield context owning the coroutine group.
|
||||||
* @return A callback to call on foreign coroutine completes or std::nullopt if the group is already full.
|
* @return A callback to call on foreign coroutine completes or std::nullopt if the group is already full.
|
||||||
*/
|
*/
|
||||||
std::optional<std::function<void()>>
|
std::optional<std::function<void()>>
|
||||||
registerForeign();
|
registerForeign(boost::asio::yield_context yield);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Wait for all the coroutines in the group to finish
|
* @brief Wait for all the coroutines in the group to finish
|
||||||
|
|||||||
@@ -118,7 +118,7 @@ public:
|
|||||||
{
|
{
|
||||||
std::optional<Response> response;
|
std::optional<Response> response;
|
||||||
util::CoroutineGroup coroutineGroup{yield, 1};
|
util::CoroutineGroup coroutineGroup{yield, 1};
|
||||||
auto const onTaskComplete = coroutineGroup.registerForeign();
|
auto const onTaskComplete = coroutineGroup.registerForeign(yield);
|
||||||
ASSERT(onTaskComplete.has_value(), "Coroutine group can't be full");
|
ASSERT(onTaskComplete.has_value(), "Coroutine group can't be full");
|
||||||
|
|
||||||
bool const postSuccessful = rpcEngine_->post(
|
bool const postSuccessful = rpcEngine_->post(
|
||||||
@@ -127,7 +127,7 @@ public:
|
|||||||
&response,
|
&response,
|
||||||
&onTaskComplete = onTaskComplete.value(),
|
&onTaskComplete = onTaskComplete.value(),
|
||||||
&connectionMetadata,
|
&connectionMetadata,
|
||||||
subscriptionContext = std::move(subscriptionContext)](boost::asio::yield_context yield) mutable {
|
subscriptionContext = std::move(subscriptionContext)](boost::asio::yield_context innerYield) mutable {
|
||||||
try {
|
try {
|
||||||
auto parsedRequest = boost::json::parse(request.message()).as_object();
|
auto parsedRequest = boost::json::parse(request.message()).as_object();
|
||||||
LOG(perfLog_.debug()) << connectionMetadata.tag() << "Adding to work queue";
|
LOG(perfLog_.debug()) << connectionMetadata.tag() << "Adding to work queue";
|
||||||
@@ -136,7 +136,11 @@ public:
|
|||||||
parsedRequest[JS(params)] = boost::json::array({boost::json::object{}});
|
parsedRequest[JS(params)] = boost::json::array({boost::json::object{}});
|
||||||
|
|
||||||
response = handleRequest(
|
response = handleRequest(
|
||||||
yield, request, std::move(parsedRequest), connectionMetadata, std::move(subscriptionContext)
|
innerYield,
|
||||||
|
request,
|
||||||
|
std::move(parsedRequest),
|
||||||
|
connectionMetadata,
|
||||||
|
std::move(subscriptionContext)
|
||||||
);
|
);
|
||||||
} catch (boost::system::system_error const& ex) {
|
} catch (boost::system::system_error const& ex) {
|
||||||
// system_error thrown when json parsing failed
|
// system_error thrown when json parsing failed
|
||||||
|
|||||||
@@ -294,14 +294,13 @@ ConnectionHandler::sequentRequestResponseLoop(
|
|||||||
|
|
||||||
LOG(log_.trace()) << connection.tag() << "Processing sequentially";
|
LOG(log_.trace()) << connection.tag() << "Processing sequentially";
|
||||||
while (true) {
|
while (true) {
|
||||||
auto expectedRequest = connection.receive(yield);
|
auto const expectedRequest = connection.receive(yield);
|
||||||
if (not expectedRequest)
|
if (not expectedRequest)
|
||||||
return handleError(expectedRequest.error(), connection);
|
return handleError(expectedRequest.error(), connection);
|
||||||
|
|
||||||
LOG(log_.info()) << connection.tag() << "Received request from ip = " << connection.ip();
|
LOG(log_.info()) << connection.tag() << "Received request from ip = " << connection.ip();
|
||||||
|
|
||||||
auto maybeReturnValue =
|
auto maybeReturnValue = processRequest(connection, subscriptionContext, expectedRequest.value(), yield);
|
||||||
processRequest(connection, subscriptionContext, std::move(expectedRequest).value(), yield);
|
|
||||||
if (maybeReturnValue.has_value())
|
if (maybeReturnValue.has_value())
|
||||||
return maybeReturnValue.value();
|
return maybeReturnValue.value();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -178,10 +178,10 @@ TEST_F(CoroutineGroupTests, SpawnForeign)
|
|||||||
runSpawn([this](boost::asio::yield_context yield) {
|
runSpawn([this](boost::asio::yield_context yield) {
|
||||||
CoroutineGroup group{yield, 1};
|
CoroutineGroup group{yield, 1};
|
||||||
|
|
||||||
auto const onForeignComplete = group.registerForeign();
|
auto const onForeignComplete = group.registerForeign(yield);
|
||||||
[&]() { ASSERT_TRUE(onForeignComplete.has_value()); }();
|
[&]() { ASSERT_TRUE(onForeignComplete.has_value()); }();
|
||||||
|
|
||||||
[&]() { ASSERT_FALSE(group.registerForeign().has_value()); }();
|
[&]() { ASSERT_FALSE(group.registerForeign(yield).has_value()); }();
|
||||||
|
|
||||||
boost::asio::spawn(ctx_, [this, &onForeignComplete](boost::asio::yield_context innerYield) {
|
boost::asio::spawn(ctx_, [this, &onForeignComplete](boost::asio::yield_context innerYield) {
|
||||||
boost::asio::steady_timer timer{innerYield.get_executor(), std::chrono::milliseconds{2}};
|
boost::asio::steady_timer timer{innerYield.get_executor(), std::chrono::milliseconds{2}};
|
||||||
|
|||||||
@@ -1 +1,4 @@
|
|||||||
{ "method": "server_definitions", "params": [ {} ] }
|
{"command":"server_info","id":72421}
|
||||||
|
{"account":"rMHDmxRrBAMUgeAriKXXfRN1ZdBs6aoCNY","ledger_index":"486998","queue":false,"signer_lists":false,"command":"account_info","id":5732}
|
||||||
|
{"ledger_index":"current","taker":"rGJRHBNZZ5v6gyoxKjt6GtitdBsWtdD79s","taker_gets":{"currency":"GZX","issuer":"r9gTsUB4hBS13QbAUWwYZPykj3wbauf8GX"},"taker_pays":{"currency":"XRP"},"command":"book_offers","id":85757}
|
||||||
|
{"ledger_index":"current","account":"r9gTsUB4hBS13QbAUWwYZPykj3wbauf8GX","command":"account_objects","id":32672}
|
||||||
|
|||||||
@@ -2,4 +2,7 @@ module requests_gun
|
|||||||
|
|
||||||
go 1.22.2
|
go 1.22.2
|
||||||
|
|
||||||
require github.com/spf13/pflag v1.0.5 // indirect
|
require (
|
||||||
|
github.com/gorilla/websocket v1.5.3
|
||||||
|
github.com/spf13/pflag v1.0.5
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,2 +1,4 @@
|
|||||||
|
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||||
|
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package ammo_provider
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"io"
|
"io"
|
||||||
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -12,17 +13,17 @@ type AmmoProvider struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ap *AmmoProvider) getIndex() uint64 {
|
func (ap *AmmoProvider) getIndex() uint64 {
|
||||||
if ap.currentBullet.Load() >= uint64(len(ap.ammo)) {
|
result := ap.currentBullet.Add(1)
|
||||||
ap.currentBullet.Store(1)
|
return result % uint64(len(ap.ammo))
|
||||||
return 0
|
|
||||||
}
|
|
||||||
result := ap.currentBullet.Load()
|
|
||||||
ap.currentBullet.Add(1)
|
|
||||||
return result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ap *AmmoProvider) GetBullet() string {
|
func (ap *AmmoProvider) GetBullet() string {
|
||||||
return ap.ammo[ap.getIndex()]
|
for {
|
||||||
|
res := ap.ammo[ap.getIndex()]
|
||||||
|
if !strings.HasPrefix(res, "#") {
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(reader io.Reader) *AmmoProvider {
|
func New(reader io.Reader) *AmmoProvider {
|
||||||
|
|||||||
@@ -7,21 +7,23 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type CliArgs struct {
|
type CliArgs struct {
|
||||||
Url string
|
Host string
|
||||||
Port uint
|
Port uint
|
||||||
TargetLoad uint
|
TargetLoad uint
|
||||||
Ammo string
|
Ammo string
|
||||||
PrintErrors bool
|
PrintErrors bool
|
||||||
Help bool
|
Help bool
|
||||||
|
Ws bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func Parse() (*CliArgs, error) {
|
func Parse() (*CliArgs, error) {
|
||||||
flag.Usage = PrintUsage
|
flag.Usage = PrintUsage
|
||||||
url := flag.StringP("url", "u", "localhost", "URL to send the request to")
|
host := flag.StringP("url", "u", "localhost", "URL to send the request to")
|
||||||
port := flag.UintP("port", "p", 51233, "Port to send the request to")
|
port := flag.UintP("port", "p", 51233, "Port to send the request to")
|
||||||
target_load := flag.UintP("load", "l", 100, "Target requests per second load")
|
target_load := flag.UintP("load", "l", 100, "Target requests per second load")
|
||||||
print_errors := flag.BoolP("print-errors", "e", false, "Print errors")
|
print_errors := flag.BoolP("print-errors", "e", false, "Print errors")
|
||||||
help := flag.BoolP("help", "h", false, "Print help message")
|
help := flag.BoolP("help", "h", false, "Print help message")
|
||||||
|
ws := flag.BoolP("ws", "w", false, "Use websocket")
|
||||||
|
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
@@ -29,7 +31,7 @@ func Parse() (*CliArgs, error) {
|
|||||||
return nil, fmt.Errorf("No ammo file provided")
|
return nil, fmt.Errorf("No ammo file provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
return &CliArgs{*url, *port, *target_load, flag.Arg(0), *print_errors, *help}, nil
|
return &CliArgs{*host, *port, *target_load, flag.Arg(0), *print_errors, *help, *ws}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func PrintUsage() {
|
func PrintUsage() {
|
||||||
|
|||||||
@@ -6,16 +6,22 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RequestMaker interface {
|
type RequestMaker interface {
|
||||||
MakeRequest(request string) (*ResponseData, error)
|
MakeRequest(request string) (*ResponseData, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type WebSocketClient struct {
|
||||||
|
conn *websocket.Conn
|
||||||
|
}
|
||||||
|
|
||||||
type HttpRequestMaker struct {
|
type HttpRequestMaker struct {
|
||||||
url string
|
host string
|
||||||
transport *http.Transport
|
transport *http.Transport
|
||||||
client *http.Client
|
client *http.Client
|
||||||
}
|
}
|
||||||
@@ -32,7 +38,7 @@ type ResponseData struct {
|
|||||||
|
|
||||||
func (h *HttpRequestMaker) MakeRequest(request string) (*ResponseData, error) {
|
func (h *HttpRequestMaker) MakeRequest(request string) (*ResponseData, error) {
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
req, err := http.NewRequest("POST", h.url, strings.NewReader(request))
|
req, err := http.NewRequest("POST", h.host, strings.NewReader(request))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.New("Error creating request: " + err.Error())
|
return nil, errors.New("Error creating request: " + err.Error())
|
||||||
}
|
}
|
||||||
@@ -72,3 +78,50 @@ func NewHttp(host string, port uint) *HttpRequestMaker {
|
|||||||
|
|
||||||
return &HttpRequestMaker{host + ":" + fmt.Sprintf("%d", port), transport, client}
|
return &HttpRequestMaker{host + ":" + fmt.Sprintf("%d", port), transport, client}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewWebSocketClient(host string, port uint) (*WebSocketClient, error) {
|
||||||
|
var u url.URL
|
||||||
|
if !strings.HasPrefix(host, "ws://") && !strings.HasPrefix(host, "wss://") {
|
||||||
|
u = url.URL{Scheme: "ws", Host: host + ":" + fmt.Sprintf("%d", port), Path: "/"}
|
||||||
|
} else {
|
||||||
|
u = url.URL{Host: host + ":" + fmt.Sprintf("%d", port), Path: "/"}
|
||||||
|
}
|
||||||
|
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("Error connecting to WebSocket: " + err.Error())
|
||||||
|
}
|
||||||
|
return &WebSocketClient{conn: conn}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendMessage sends a message to the WebSocket server
|
||||||
|
func (ws *WebSocketClient) SendMessage(message string) (*ResponseData, error) {
|
||||||
|
defer ws.conn.Close()
|
||||||
|
start := time.Now()
|
||||||
|
err := ws.conn.WriteMessage(websocket.TextMessage, []byte(message))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("Error sending ws message: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg []byte
|
||||||
|
err = ws.conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("Error setting timeout: " + err.Error())
|
||||||
|
}
|
||||||
|
_, msg, err = ws.conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("Error reading message: " + err.Error())
|
||||||
|
}
|
||||||
|
requestDuration := time.Since(start)
|
||||||
|
ws.conn.Close()
|
||||||
|
|
||||||
|
var response JsonMap
|
||||||
|
err = json.Unmarshal(msg, &response)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("Error unmarshaling message: " + err.Error())
|
||||||
|
}
|
||||||
|
return &ResponseData{response, StatusCode(200), "WS Ok", requestDuration}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebSocketClient) Close() error {
|
||||||
|
return ws.conn.Close()
|
||||||
|
}
|
||||||
|
|||||||
@@ -30,10 +30,20 @@ func Fire(ammoProvider *ammo_provider.AmmoProvider, args *parse_args.CliArgs) {
|
|||||||
doShot := func() {
|
doShot := func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
bullet := ammoProvider.GetBullet()
|
bullet := ammoProvider.GetBullet()
|
||||||
requestMaker := request_maker.NewHttp(args.Url, args.Port)
|
if args.Ws {
|
||||||
|
wsClient, err := request_maker.NewWebSocketClient(args.Host, args.Port)
|
||||||
|
if err != nil {
|
||||||
|
statistics.add(nil, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
responseData, err := wsClient.SendMessage(bullet)
|
||||||
|
statistics.add(responseData, err)
|
||||||
|
} else {
|
||||||
|
requestMaker := request_maker.NewHttp(args.Host, args.Port)
|
||||||
responseData, err := requestMaker.MakeRequest(bullet)
|
responseData, err := requestMaker.MakeRequest(bullet)
|
||||||
statistics.add(responseData, err)
|
statistics.add(responseData, err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
secondStart := time.Now()
|
secondStart := time.Now()
|
||||||
requestsNumber := uint(0)
|
requestsNumber := uint(0)
|
||||||
@@ -74,11 +84,10 @@ func (s *statistics) add(response *request_maker.ResponseData, err error) {
|
|||||||
}
|
}
|
||||||
if response.StatusCode != 200 || response.Body["error"] != nil {
|
if response.StatusCode != 200 || response.Body["error"] != nil {
|
||||||
if s.printErrors {
|
if s.printErrors {
|
||||||
log.Print("Response contains error: ", response.StatusStr)
|
|
||||||
if response.Body["error"] != nil {
|
if response.Body["error"] != nil {
|
||||||
log.Println(" ", response.Body["error"])
|
log.Print("Response contains error: ", response.Body["error"])
|
||||||
} else {
|
} else {
|
||||||
log.Println()
|
log.Print("Got bad status: ", response.StatusCode, response.StatusStr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
s.counters.badReply.Add(1)
|
s.counters.badReply.Add(1)
|
||||||
|
|||||||
@@ -16,6 +16,11 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if args.Help {
|
||||||
|
parse_args.PrintUsage()
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
|
||||||
fmt.Print("Loading ammo... ")
|
fmt.Print("Loading ammo... ")
|
||||||
f, err := os.Open(args.Ammo)
|
f, err := os.Open(args.Ammo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user