diff --git a/src/util/CoroutineGroup.cpp b/src/util/CoroutineGroup.cpp index 195df2f3..b288859b 100644 --- a/src/util/CoroutineGroup.cpp +++ b/src/util/CoroutineGroup.cpp @@ -56,13 +56,15 @@ CoroutineGroup::spawn(boost::asio::yield_context yield, std::function> -CoroutineGroup::registerForeign() +CoroutineGroup::registerForeign(boost::asio::yield_context yield) { if (isFull()) return std::nullopt; ++childrenCounter_; - return [this]() { onCoroutineCompleted(); }; + // It is important to spawn onCoroutineCompleted() to the same coroutine as will be calling asyncWait(). + // timer_ here is not thread safe, so without spawn there could be a data race. + return [this, yield]() { boost::asio::spawn(yield, [this](auto&&) { onCoroutineCompleted(); }); }; } void diff --git a/src/util/CoroutineGroup.hpp b/src/util/CoroutineGroup.hpp index 9e5b70b1..b14c3b0f 100644 --- a/src/util/CoroutineGroup.hpp +++ b/src/util/CoroutineGroup.hpp @@ -73,10 +73,11 @@ public: * @note A foreign coroutine is still counted as a child one, i.e. calling this method increases the size of the * group. * + * @param yield The yield context owning the coroutine group. * @return A callback to call on foreign coroutine completes or std::nullopt if the group is already full. */ std::optional> - registerForeign(); + registerForeign(boost::asio::yield_context yield); /** * @brief Wait for all the coroutines in the group to finish diff --git a/src/web/ng/RPCServerHandler.hpp b/src/web/ng/RPCServerHandler.hpp index 5865748f..3453489d 100644 --- a/src/web/ng/RPCServerHandler.hpp +++ b/src/web/ng/RPCServerHandler.hpp @@ -118,7 +118,7 @@ public: { std::optional response; util::CoroutineGroup coroutineGroup{yield, 1}; - auto const onTaskComplete = coroutineGroup.registerForeign(); + auto const onTaskComplete = coroutineGroup.registerForeign(yield); ASSERT(onTaskComplete.has_value(), "Coroutine group can't be full"); bool const postSuccessful = rpcEngine_->post( @@ -127,7 +127,7 @@ public: &response, &onTaskComplete = onTaskComplete.value(), &connectionMetadata, - subscriptionContext = std::move(subscriptionContext)](boost::asio::yield_context yield) mutable { + subscriptionContext = std::move(subscriptionContext)](boost::asio::yield_context innerYield) mutable { try { auto parsedRequest = boost::json::parse(request.message()).as_object(); LOG(perfLog_.debug()) << connectionMetadata.tag() << "Adding to work queue"; @@ -136,7 +136,11 @@ public: parsedRequest[JS(params)] = boost::json::array({boost::json::object{}}); response = handleRequest( - yield, request, std::move(parsedRequest), connectionMetadata, std::move(subscriptionContext) + innerYield, + request, + std::move(parsedRequest), + connectionMetadata, + std::move(subscriptionContext) ); } catch (boost::system::system_error const& ex) { // system_error thrown when json parsing failed diff --git a/src/web/ng/impl/ConnectionHandler.cpp b/src/web/ng/impl/ConnectionHandler.cpp index 0f2cd50d..adfd8bc1 100644 --- a/src/web/ng/impl/ConnectionHandler.cpp +++ b/src/web/ng/impl/ConnectionHandler.cpp @@ -294,14 +294,13 @@ ConnectionHandler::sequentRequestResponseLoop( LOG(log_.trace()) << connection.tag() << "Processing sequentially"; while (true) { - auto expectedRequest = connection.receive(yield); + auto const expectedRequest = connection.receive(yield); if (not expectedRequest) return handleError(expectedRequest.error(), connection); LOG(log_.info()) << connection.tag() << "Received request from ip = " << connection.ip(); - auto maybeReturnValue = - processRequest(connection, subscriptionContext, std::move(expectedRequest).value(), yield); + auto maybeReturnValue = processRequest(connection, subscriptionContext, expectedRequest.value(), yield); if (maybeReturnValue.has_value()) return maybeReturnValue.value(); } diff --git a/tests/unit/util/CoroutineGroupTests.cpp b/tests/unit/util/CoroutineGroupTests.cpp index 1136293f..3feeb50b 100644 --- a/tests/unit/util/CoroutineGroupTests.cpp +++ b/tests/unit/util/CoroutineGroupTests.cpp @@ -178,10 +178,10 @@ TEST_F(CoroutineGroupTests, SpawnForeign) runSpawn([this](boost::asio::yield_context yield) { CoroutineGroup group{yield, 1}; - auto const onForeignComplete = group.registerForeign(); + auto const onForeignComplete = group.registerForeign(yield); [&]() { ASSERT_TRUE(onForeignComplete.has_value()); }(); - [&]() { ASSERT_FALSE(group.registerForeign().has_value()); }(); + [&]() { ASSERT_FALSE(group.registerForeign(yield).has_value()); }(); boost::asio::spawn(ctx_, [this, &onForeignComplete](boost::asio::yield_context innerYield) { boost::asio::steady_timer timer{innerYield.get_executor(), std::chrono::milliseconds{2}}; diff --git a/tools/requests_gun/ammo.txt b/tools/requests_gun/ammo.txt index b724bfa3..d7c12dad 100644 --- a/tools/requests_gun/ammo.txt +++ b/tools/requests_gun/ammo.txt @@ -1 +1,4 @@ -{ "method": "server_definitions", "params": [ {} ] } +{"command":"server_info","id":72421} +{"account":"rMHDmxRrBAMUgeAriKXXfRN1ZdBs6aoCNY","ledger_index":"486998","queue":false,"signer_lists":false,"command":"account_info","id":5732} +{"ledger_index":"current","taker":"rGJRHBNZZ5v6gyoxKjt6GtitdBsWtdD79s","taker_gets":{"currency":"GZX","issuer":"r9gTsUB4hBS13QbAUWwYZPykj3wbauf8GX"},"taker_pays":{"currency":"XRP"},"command":"book_offers","id":85757} +{"ledger_index":"current","account":"r9gTsUB4hBS13QbAUWwYZPykj3wbauf8GX","command":"account_objects","id":32672} diff --git a/tools/requests_gun/go.mod b/tools/requests_gun/go.mod index e87f78dc..e7319361 100644 --- a/tools/requests_gun/go.mod +++ b/tools/requests_gun/go.mod @@ -2,4 +2,7 @@ module requests_gun go 1.22.2 -require github.com/spf13/pflag v1.0.5 // indirect +require ( + github.com/gorilla/websocket v1.5.3 + github.com/spf13/pflag v1.0.5 +) diff --git a/tools/requests_gun/go.sum b/tools/requests_gun/go.sum index 287f6fa8..61a46119 100644 --- a/tools/requests_gun/go.sum +++ b/tools/requests_gun/go.sum @@ -1,2 +1,4 @@ +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= diff --git a/tools/requests_gun/internal/ammo_provider/ammo_provider.go b/tools/requests_gun/internal/ammo_provider/ammo_provider.go index 00543198..ad664cec 100644 --- a/tools/requests_gun/internal/ammo_provider/ammo_provider.go +++ b/tools/requests_gun/internal/ammo_provider/ammo_provider.go @@ -3,6 +3,7 @@ package ammo_provider import ( "bufio" "io" + "strings" "sync/atomic" ) @@ -12,17 +13,17 @@ type AmmoProvider struct { } func (ap *AmmoProvider) getIndex() uint64 { - if ap.currentBullet.Load() >= uint64(len(ap.ammo)) { - ap.currentBullet.Store(1) - return 0 - } - result := ap.currentBullet.Load() - ap.currentBullet.Add(1) - return result + result := ap.currentBullet.Add(1) + return result % uint64(len(ap.ammo)) } func (ap *AmmoProvider) GetBullet() string { - return ap.ammo[ap.getIndex()] + for { + res := ap.ammo[ap.getIndex()] + if !strings.HasPrefix(res, "#") { + return res + } + } } func New(reader io.Reader) *AmmoProvider { diff --git a/tools/requests_gun/internal/parse_args/parse_args.go b/tools/requests_gun/internal/parse_args/parse_args.go index 1f1a8971..7b937589 100644 --- a/tools/requests_gun/internal/parse_args/parse_args.go +++ b/tools/requests_gun/internal/parse_args/parse_args.go @@ -7,21 +7,23 @@ import ( ) type CliArgs struct { - Url string + Host string Port uint TargetLoad uint Ammo string PrintErrors bool Help bool + Ws bool } func Parse() (*CliArgs, error) { flag.Usage = PrintUsage - url := flag.StringP("url", "u", "localhost", "URL to send the request to") + host := flag.StringP("url", "u", "localhost", "URL to send the request to") port := flag.UintP("port", "p", 51233, "Port to send the request to") target_load := flag.UintP("load", "l", 100, "Target requests per second load") print_errors := flag.BoolP("print-errors", "e", false, "Print errors") help := flag.BoolP("help", "h", false, "Print help message") + ws := flag.BoolP("ws", "w", false, "Use websocket") flag.Parse() @@ -29,7 +31,7 @@ func Parse() (*CliArgs, error) { return nil, fmt.Errorf("No ammo file provided") } - return &CliArgs{*url, *port, *target_load, flag.Arg(0), *print_errors, *help}, nil + return &CliArgs{*host, *port, *target_load, flag.Arg(0), *print_errors, *help, *ws}, nil } func PrintUsage() { diff --git a/tools/requests_gun/internal/request_maker/request_maker.go b/tools/requests_gun/internal/request_maker/request_maker.go index dd2c55ad..670bc1fa 100644 --- a/tools/requests_gun/internal/request_maker/request_maker.go +++ b/tools/requests_gun/internal/request_maker/request_maker.go @@ -6,16 +6,22 @@ import ( "fmt" "io" "net/http" + "net/url" "strings" "time" + "github.com/gorilla/websocket" ) type RequestMaker interface { MakeRequest(request string) (*ResponseData, error) } +type WebSocketClient struct { + conn *websocket.Conn +} + type HttpRequestMaker struct { - url string + host string transport *http.Transport client *http.Client } @@ -32,7 +38,7 @@ type ResponseData struct { func (h *HttpRequestMaker) MakeRequest(request string) (*ResponseData, error) { startTime := time.Now() - req, err := http.NewRequest("POST", h.url, strings.NewReader(request)) + req, err := http.NewRequest("POST", h.host, strings.NewReader(request)) if err != nil { return nil, errors.New("Error creating request: " + err.Error()) } @@ -72,3 +78,50 @@ func NewHttp(host string, port uint) *HttpRequestMaker { return &HttpRequestMaker{host + ":" + fmt.Sprintf("%d", port), transport, client} } + +func NewWebSocketClient(host string, port uint) (*WebSocketClient, error) { + var u url.URL + if !strings.HasPrefix(host, "ws://") && !strings.HasPrefix(host, "wss://") { + u = url.URL{Scheme: "ws", Host: host + ":" + fmt.Sprintf("%d", port), Path: "/"} + } else { + u = url.URL{Host: host + ":" + fmt.Sprintf("%d", port), Path: "/"} + } + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + if err != nil { + return nil, errors.New("Error connecting to WebSocket: " + err.Error()) + } + return &WebSocketClient{conn: conn}, nil +} + +// SendMessage sends a message to the WebSocket server +func (ws *WebSocketClient) SendMessage(message string) (*ResponseData, error) { + defer ws.conn.Close() + start := time.Now() + err := ws.conn.WriteMessage(websocket.TextMessage, []byte(message)) + if err != nil { + return nil, errors.New("Error sending ws message: " + err.Error()) + } + + var msg []byte + err = ws.conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + return nil, errors.New("Error setting timeout: " + err.Error()) + } + _, msg, err = ws.conn.ReadMessage() + if err != nil { + return nil, errors.New("Error reading message: " + err.Error()) + } + requestDuration := time.Since(start) + ws.conn.Close() + + var response JsonMap + err = json.Unmarshal(msg, &response) + if err != nil { + return nil, errors.New("Error unmarshaling message: " + err.Error()) + } + return &ResponseData{response, StatusCode(200), "WS Ok", requestDuration}, nil +} + +func (ws *WebSocketClient) Close() error { + return ws.conn.Close() +} diff --git a/tools/requests_gun/internal/trigger/trigger.go b/tools/requests_gun/internal/trigger/trigger.go index 7d71d21c..d4df0d30 100644 --- a/tools/requests_gun/internal/trigger/trigger.go +++ b/tools/requests_gun/internal/trigger/trigger.go @@ -30,9 +30,19 @@ func Fire(ammoProvider *ammo_provider.AmmoProvider, args *parse_args.CliArgs) { doShot := func() { defer wg.Done() bullet := ammoProvider.GetBullet() - requestMaker := request_maker.NewHttp(args.Url, args.Port) - responseData, err := requestMaker.MakeRequest(bullet) - statistics.add(responseData, err) + if args.Ws { + wsClient, err := request_maker.NewWebSocketClient(args.Host, args.Port) + if err != nil { + statistics.add(nil, err) + return + } + responseData, err := wsClient.SendMessage(bullet) + statistics.add(responseData, err) + } else { + requestMaker := request_maker.NewHttp(args.Host, args.Port) + responseData, err := requestMaker.MakeRequest(bullet) + statistics.add(responseData, err) + } } secondStart := time.Now() @@ -74,11 +84,10 @@ func (s *statistics) add(response *request_maker.ResponseData, err error) { } if response.StatusCode != 200 || response.Body["error"] != nil { if s.printErrors { - log.Print("Response contains error: ", response.StatusStr) if response.Body["error"] != nil { - log.Println(" ", response.Body["error"]) + log.Print("Response contains error: ", response.Body["error"]) } else { - log.Println() + log.Print("Got bad status: ", response.StatusCode, response.StatusStr) } } s.counters.badReply.Add(1) diff --git a/tools/requests_gun/requests_gun.go b/tools/requests_gun/requests_gun.go index b6189f19..04b895d7 100644 --- a/tools/requests_gun/requests_gun.go +++ b/tools/requests_gun/requests_gun.go @@ -16,6 +16,11 @@ func main() { os.Exit(1) } + if args.Help { + parse_args.PrintUsage() + os.Exit(0) + } + fmt.Print("Loading ammo... ") f, err := os.Open(args.Ammo) if err != nil {