diff --git a/src/comm/comm_session.cpp b/src/comm/comm_session.cpp index afec3d35..d01889cc 100644 --- a/src/comm/comm_session.cpp +++ b/src/comm/comm_session.cpp @@ -179,7 +179,6 @@ namespace comm if (out_msg_queue.try_dequeue(msg_to_send)) { process_outbound_message(msg_to_send); - util::sleep(1); } else { diff --git a/src/hpws/hpws.hpp b/src/hpws/hpws.hpp index 004f2137..03703de3 100644 --- a/src/hpws/hpws.hpp +++ b/src/hpws/hpws.hpp @@ -69,32 +69,27 @@ namespace hpws uint32_t max_buffer_size; bool moved = false; addr_t endpoint; - std::string get; // the get req this websocket was opened with - int control_line_fd; + std::string get; // the get req this websocket was opened with + int control_line_fd[2]; // see below in client constructor int buffer_fd[4]; // 0 1 - in buffers, 2 3 - out buffers int buffer_lock[2] = {0, 0}; // this records if buffers 2 and 3 have been sent out awaiting an ack or not void *buffer[4]; - int pending_read[2] = {0, 0}; // if we receive a read message in a non-read function we place the pending size - // here in position 0 if buffer 0 and position 1 if buffer 1, then when read is - // called we return immediately with the content - // to prevent pending buffers becoming out of order a read counter is kept incrementing for each read - // and when there are pending reads the counter at the time of the read is inserted into this array - uint64_t pending_read_counter[2] = {0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL}; - uint64_t read_counter = 0; // private constructor client( std::string_view get, addr_t endpoint, - int control_line_fd, + int control_line_fd_0, // hpws -> hpcore [ hpws sends bufs to us over this line ] + int control_line_fd_1, // hpcore -> hpws [ we send bufs to hpws over this line ] uint32_t max_buffer_size, pid_t child_pid, int buffer_fd[4], void *buffer[4]) : endpoint(endpoint), - control_line_fd(control_line_fd), max_buffer_size(max_buffer_size), child_pid(child_pid), get(get) { + control_line_fd[0] = control_line_fd_0; + control_line_fd[1] = control_line_fd_1; for (int i = 0; i < 4; ++i) { this->buffer[i] = buffer[i]; @@ -113,23 +108,19 @@ namespace hpws client(client &&old) : child_pid(old.child_pid), max_buffer_size(old.max_buffer_size), endpoint(old.endpoint), - control_line_fd(old.control_line_fd), get(old.get) { old.moved = true; + for (int i = 0; i <= 1; ++i) + { + this->control_line_fd[i] = old.control_line_fd[i]; + buffer_lock[i] = old.buffer_lock[i]; + } for (int i = 0; i < 4; ++i) { this->buffer[i] = old.buffer[i]; this->buffer_fd[i] = old.buffer_fd[i]; } - - for (int i = 0; i < 2; ++i) - { - buffer_lock[i] = old.buffer_lock[i]; - pending_read[i] = old.pending_read[i]; - pending_read_counter[i] = old.pending_read_counter[i]; - read_counter = old.read_counter; - } } ~client() @@ -151,7 +142,8 @@ namespace hpws close(buffer_fd[i]); } - close(control_line_fd); + close(control_line_fd[0]); + close(control_line_fd[1]); } } @@ -169,38 +161,15 @@ namespace hpws { unsigned char buf[32]; - int bytes_read = 0; - - read_start:; - - // if during writng we got a read message it's queued as a pending read, so process that first - int do_pending_read = -1; - if (pending_read[0] || pending_read[1]) - do_pending_read = (pending_read_counter[0] > pending_read_counter[1] ? 1 : 0); - - if (do_pending_read > -1) + int bytes_read = recv(control_line_fd[0], buf, sizeof(buf), 0); + if (bytes_read < 1) { if (HPWS_DEBUG) - fprintf(stderr, "[HPWS.HPP] pending read from buffer %d\n", do_pending_read); - bytes_read = pending_read[do_pending_read]; - uint32_t len = pending_read[do_pending_read]; - pending_read[do_pending_read] = 0; - pending_read_counter[do_pending_read] = 0xFFFFFFFFFFFFFFFFULL; - return std::string_view{(const char *)(buffer[do_pending_read]), len}; - } - else - { - - bytes_read = recv(control_line_fd, buf, sizeof(buf), 0); - if (bytes_read < 1) { - if (HPWS_DEBUG) - { - perror("recv"); - fprintf(stderr, "[HPWS.HPP] bytes received %d\n", bytes_read); - } - return error{1, "[read] control line could not be read"}; // todo clean up somehow? + perror("recv"); + fprintf(stderr, "[HPWS.HPP] bytes received %d\n", bytes_read); } + return error{1, "[read] control line could not be read"}; // todo clean up somehow? } switch (buf[0]) @@ -212,8 +181,7 @@ namespace hpws if (bytes_read != 6) return error{3, "invalid buffer in 'o' command sent by hpws"}; - ++read_counter; - // there's a pending buffer for us + uint32_t len = 0; DECODE_O_SIZE(buf, len); @@ -233,36 +201,35 @@ namespace hpws } return std::string_view{(const char *)(buffer[bufno]), len}; } - case 'a': - { - if (bytes_read != 2) - return error{4, "received an ack longer than 2 bytes"}; - int bufno = buf[1] - '0'; - if (!(bufno == 0 || bufno == 1)) - return error{5, "received an ack with an invalid buffer, expecting 0 or 1"}; - // unlock the buffer - buffer_lock[bufno] = 0; - goto read_start; - } case 'c': + { return error{1000, "ws closed"}; + } default: + { fprintf(stderr, "[HPWS.HPP] read unknown control message 1: `%.*s`\n", bytes_read, buf); return error{2, "unknown control line command was sent by hpws"}; } + } } std::optional write(std::string_view to_write) { + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] write called for message: `%.*s`\n", + (int)to_write.size(), to_write.data()); + // check if we have any free buffers if (buffer_lock[0] && buffer_lock[1]) { + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] no free buffers while writing, waiting for an ack on control_line_fd[1]=%d\n", control_line_fd[1]); + // no free buffers, wait for a ack unsigned char buf[32]; int bytes_read = 0; - write_start:; - bytes_read = recv(control_line_fd, buf, sizeof(buf), 0); + bytes_read = recv(control_line_fd[1], buf, sizeof(buf), 0); if (bytes_read < 1) { perror("recv"); @@ -271,21 +238,6 @@ namespace hpws switch (buf[0]) { - case 'o': - { - if (bytes_read != 6) - return error{3, "invalid buffer in 'o' command sent by hpws"}; - ++read_counter; - uint32_t len = 0; - DECODE_O_SIZE(buf, len); - - int bufno = buf[1] - '0'; - if (bufno != 0 && bufno != 1) - return error{3, "invalid buffer in 'o' command sent by hpws"}; - pending_read[bufno] = len; - pending_read_counter[bufno] = read_counter; - goto write_start; - } case 'a': { if (bytes_read != 2) @@ -319,7 +271,11 @@ namespace hpws char buf[6] = {'o', (char)('0' + (bufno - 2)), 0, 0, 0, 0}; ENCODE_O_SIZE(buf, len); - if (::write(control_line_fd, buf, 6) != 6) + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] writing 'o' to control_line_fd[1]=%d\n", + control_line_fd[1]); + + if (::write(control_line_fd[1], buf, 6) != 6) return error{6, "could not write o message to control line"}; return std::nullopt; @@ -330,7 +286,7 @@ namespace hpws char msg[2] = {'a', '0'}; if (from_read.data() == buffer[1]) msg[1]++; - if (send(control_line_fd, msg, 2, 0) < 2) + if (send(control_line_fd[0], msg, 2, 0) < 2) return error{10, "could not send ack down control line"}; return std::nullopt; } @@ -354,14 +310,17 @@ namespace hpws int error_code = -1; const char *error_msg = NULL; - int fd[2] = {-1, -1}; + int fd[4] = {-1, -1, -1, -1}; // 0,1 are hpws->hpcore, 2,3 are hpcore->hpws int pid = -1; - int count_args = 12 + argv.size(); + int count_args = 14 + argv.size(); char const **argv_pass = NULL; if (socketpair(AF_UNIX, SOCK_SEQPACKET, 0, fd)) HPWS_CONNECT_ERROR(100, "could not create unix domain socket pair"); + if (socketpair(AF_UNIX, SOCK_SEQPACKET, 0, fd + 2)) + HPWS_CONNECT_ERROR(101, "could not create unix domain socket pair"); + // construct the arguments char shm_size[32]; @@ -372,6 +331,12 @@ namespace hpws if (snprintf(port_str, 6, "%d", port) <= 0) HPWS_CONNECT_ERROR(91, "couldn't write port to string"); + char cfd1[10]; + char cfd2[10]; + + snprintf(cfd1, 10, "%d", fd[1]); + snprintf(cfd2, 10, "%d", fd[3]); + argv_pass = reinterpret_cast(alloca(sizeof(char *) * count_args)); { @@ -385,7 +350,9 @@ namespace hpws argv_pass[upto++] = "--port"; argv_pass[upto++] = port_str; argv_pass[upto++] = "--cntlfd"; - argv_pass[upto++] = "3"; + argv_pass[upto++] = cfd1; + argv_pass[upto++] = "--cntlfd2"; + argv_pass[upto++] = cfd2; argv_pass[upto++] = "--get"; argv_pass[upto++] = get.data(); for (std::string_view &arg : argv) @@ -401,22 +368,27 @@ namespace hpws // --- PARENT close(fd[1]); + close(fd[3]); - int child_fd = fd[0]; + int child_fd[2] = {fd[0], fd[2]}; - int flags = fcntl(child_fd, F_GETFD, NULL); - if (flags < 0) + int flags[2] = { + fcntl(child_fd[0], F_GETFD, NULL), + fcntl(child_fd[1], F_GETFD, NULL)}; + + if (flags[0] < 0 || flags[1] < 0) HPWS_CONNECT_ERROR(101, "could not get flags from unix domain socket"); - flags |= FD_CLOEXEC; - if (fcntl(child_fd, F_SETFD, flags)) + flags[0] |= FD_CLOEXEC; + flags[1] |= FD_CLOEXEC; + if (fcntl(child_fd[0], F_SETFD, flags[0]) || fcntl(child_fd[1], F_SETFD, flags[1])) HPWS_CONNECT_ERROR(102, "could notset flags for unix domain socket"); // we will set a timeout and wait for the initial startup message from hpws client mode struct pollfd pfd; int ret; - pfd.fd = child_fd; + pfd.fd = child_fd[0]; // we receive setup events on control line 0 (hpws->hpcore) pfd.events = POLLIN; ret = poll(&pfd, 1, HPWS_LONG_TIMEOUT); // default= 1500 ms timeout @@ -430,7 +402,7 @@ namespace hpws addr_t child_addr; int bytes_read = - recv(child_fd, (unsigned char *)(&child_addr), sizeof(child_addr), 0); + recv(child_fd[0], (unsigned char *)(&child_addr), sizeof(child_addr), 0); if (bytes_read < sizeof(child_addr)) HPWS_CONNECT_ERROR(202, "received message on control line was not sizeof(addr_t)"); @@ -449,7 +421,7 @@ namespace hpws child_msg.msg_controllen = sizeof(cmsgbuf); int bytes_read = - recvmsg(child_fd, &child_msg, 0); + recvmsg(child_fd[0], &child_msg, 0); struct cmsghdr *cmsg = CMSG_FIRSTHDR(&child_msg); if (cmsg == NULL || cmsg->cmsg_type != SCM_RIGHTS) HPWS_CONNECT_ERROR(203, "non-scm_rights message sent on accept child control line"); @@ -466,24 +438,35 @@ namespace hpws } } - if (HPWS_DEBUG) - fprintf(stderr, "[HPWS.HPP] waiting for 'r'\n"); + for (int i = 0; i <= 1; ++i) + { + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] waiting for 'r' on child_fd[%d]=%d\n", i, child_fd[i]); - // now we wait for a 'r' ready message or for the socket/client to die - ret = poll(&pfd, 1, HPWS_LONG_TIMEOUT); // default= 1500 ms timeout + pfd.fd = child_fd[i]; + // now we wait for a 'r' ready message or for the socket/client to die + ret = poll(&pfd, 1, HPWS_LONG_TIMEOUT); // default= 1500 ms timeout - char rbuf[1]; - bytes_read = recv(fd[0], rbuf, sizeof(rbuf), 0); - if (bytes_read < 1) - HPWS_CONNECT_ERROR(2, "nil message sent by hpws on startup"); + char rbuf[2]; + bytes_read = recv(child_fd[i], rbuf, sizeof(rbuf), 0); + if (bytes_read < 1) + HPWS_CONNECT_ERROR(2, "nil message sent by hpws on startup"); - if (rbuf[0] != 'r') - HPWS_CONNECT_ERROR(3, "unexpected content in message sent by hpws client mode on startup"); + if (rbuf[0] != 'r') + HPWS_CONNECT_ERROR(3, "unexpected content in message sent by hpws client mode on startup"); + + if (rbuf[1] != '0' + i) + HPWS_CONNECT_ERROR(4, "received wrong r message on control fd"); + + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] received 'r%c' on child_fd[%d]=%d\n", rbuf[1], i, child_fd[i]); + } return client{ get, child_addr, - child_fd, + child_fd[0], + child_fd[1], max_buffer_size, pid, buffer_fd, @@ -497,10 +480,16 @@ namespace hpws fork_child_init(); close(fd[0]); + close(fd[2]); // dup fd[1] into fd 3 - dup2(fd[1], 3); - close(fd[1]); + /*if (dup2(fd[1], 3) == -1) + perror("dup2 fd[1]"); + if (dup2(fd[3], 4) == -1) + perror("dup2 fd[3]"); + */ + // close(fd[1]); + // close(fd[3]); // we're assuming all fds above 3 will have close_exec flag execv(bin_path.data(), (char *const *)argv_pass); @@ -521,10 +510,9 @@ namespace hpws int status; waitpid(pid, &status, 0 /* should we use WNOHANG? */); } - if (fd[0] > 0) - close(fd[0]); - if (fd[1] > 0) - close(fd[1]); + for (int i = 0; i < 4; ++i) + if (fd[i] > 0) + close(fd[i]); return error{error_code, std::string{error_msg}}; } @@ -578,11 +566,11 @@ namespace hpws return error{code, msg}; \ } - int child_fd = -1; + int child_fd[2] = {-1, -1}; { struct msghdr child_msg = {0}; memset(&child_msg, 0, sizeof(child_msg)); - char cmsgbuf[CMSG_SPACE(sizeof(int))]; + char cmsgbuf[CMSG_SPACE(sizeof(int) * 2)]; child_msg.msg_control = cmsgbuf; child_msg.msg_controllen = sizeof(cmsgbuf); @@ -608,15 +596,18 @@ namespace hpws if (cmsg == NULL || cmsg->cmsg_type != SCM_RIGHTS) HPWS_ACCEPT_ERROR(200, "non-scm_rights message sent on master control line"); memcpy(&child_fd, CMSG_DATA(cmsg), sizeof(child_fd)); - if (child_fd < 0) - HPWS_ACCEPT_ERROR(201, "scm_rights passed fd was negative"); + if (child_fd[0] < 0 || child_fd[1] < 0) + HPWS_ACCEPT_ERROR(201, "scm_rights passed fd/s were negative"); + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] On accept received SCM: child_fd[0] = %d, child_fd[1] = %d\n", + child_fd[0], child_fd[1]); } // read info from child control line with a timeout struct pollfd pfd; int ret; - pfd.fd = child_fd; + pfd.fd = child_fd[0]; // expect all setup messages on the hpws->hpcore controlfd (0) pfd.events = POLLIN; ret = poll(&pfd, 1, HPWS_SMALL_TIMEOUT); // 1 ms timeout @@ -627,13 +618,13 @@ namespace hpws // first thing we'll receive is the pid of the client // must not use pid_t here since we transfer across IPC channel as a uint32. uint32_t pid = 0; - if (recv(child_fd, (unsigned char *)(&pid), sizeof(pid), 0) < sizeof(pid)) + if (recv(child_fd[0], (unsigned char *)(&pid), sizeof(pid), 0) < sizeof(pid)) HPWS_ACCEPT_ERROR(212, "did not receive expected 4 byte pid of child process on accept"); // second thing we'll receive is IP address structure of the client addr_t buf; int bytes_read = - recv(child_fd, (unsigned char *)(&buf), sizeof(buf), 0); + recv(child_fd[0], (unsigned char *)(&buf), sizeof(buf), 0); if (bytes_read < sizeof(buf)) return error{202, "received message on master control line was not sizeof(sockaddr_in6)"}; @@ -649,7 +640,7 @@ namespace hpws child_msg.msg_controllen = sizeof(cmsgbuf); int bytes_read = - recvmsg(child_fd, &child_msg, 0); + recvmsg(child_fd[0], &child_msg, 0); struct cmsghdr *cmsg = CMSG_FIRSTHDR(&child_msg); if (cmsg == NULL || cmsg->cmsg_type != SCM_RIGHTS) return error{203, "non-scm_rights message sent on accept child control line"}; @@ -666,30 +657,40 @@ namespace hpws } } { - if (HPWS_DEBUG) - fprintf(stderr, "[HPWS.HPP] waiting for 'r' on accept\n"); struct pollfd pfd; int ret; - pfd.fd = child_fd; - pfd.events = POLLIN; - // now we wait for a 'r' ready message or for the socket/client to die - ret = poll(&pfd, 1, HPWS_LONG_TIMEOUT); // default= 1500 ms timeout + for (int i = 0; i <= 1; ++i) + { + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] waiting for 'r' on child_fd[%d]=%d accept\n", i, child_fd[i]); + pfd.fd = child_fd[i]; + pfd.events = POLLIN; + // now we wait for a 'r' ready message or for the socket/client to die + ret = poll(&pfd, 1, HPWS_LONG_TIMEOUT); // default= 1500 ms timeout - char rbuf[1]; - bytes_read = recv(child_fd, rbuf, sizeof(rbuf), 0); - if (bytes_read < 1) - HPWS_ACCEPT_ERROR(2, "nil message sent by hpws on startup on accept"); + char rbuf[2]; + bytes_read = recv(child_fd[i], rbuf, sizeof(rbuf), 0); + if (bytes_read < 1) + HPWS_ACCEPT_ERROR(2, "nil message sent by hpws on startup on accept"); - if (rbuf[0] != 'r') - HPWS_ACCEPT_ERROR(3, "unexpected content in message sent by hpws client mode on startup"); + if (rbuf[0] != 'r') + HPWS_ACCEPT_ERROR(3, "unexpected content in message sent by hpws client mode on startup"); + + if (rbuf[1] != '0' + i) + HPWS_ACCEPT_ERROR(4, "received wrong r message on control line fd"); + + if (HPWS_DEBUG) + fprintf(stderr, "[HPWS.HPP] 'r%c' received on child_fd[%d]=%d\n", rbuf[1], i, child_fd[i]); + } } // RH TODO: accept needs a proper child cleanup on failure return client{ "", buf, - child_fd, + child_fd[0], + child_fd[1], max_buffer_size_, (pid_t)pid, buffer_fd, diff --git a/test/bin/hpws b/test/bin/hpws index f078f862..3eabf3dc 100755 Binary files a/test/bin/hpws and b/test/bin/hpws differ