Updated hpws binary to fix multi-threaded read/write.

This commit is contained in:
ravinsp
2020-10-24 06:39:07 +05:30
parent fabfdcce89
commit 86fba4e87a
3 changed files with 131 additions and 131 deletions

View File

@@ -179,7 +179,6 @@ namespace comm
if (out_msg_queue.try_dequeue(msg_to_send))
{
process_outbound_message(msg_to_send);
util::sleep(1);
}
else
{

View File

@@ -69,32 +69,27 @@ namespace hpws
uint32_t max_buffer_size;
bool moved = false;
addr_t endpoint;
std::string get; // the get req this websocket was opened with
int control_line_fd;
std::string get; // the get req this websocket was opened with
int control_line_fd[2]; // see below in client constructor
int buffer_fd[4]; // 0 1 - in buffers, 2 3 - out buffers
int buffer_lock[2] = {0, 0}; // this records if buffers 2 and 3 have been sent out awaiting an ack or not
void *buffer[4];
int pending_read[2] = {0, 0}; // if we receive a read message in a non-read function we place the pending size
// here in position 0 if buffer 0 and position 1 if buffer 1, then when read is
// called we return immediately with the content
// to prevent pending buffers becoming out of order a read counter is kept incrementing for each read
// and when there are pending reads the counter at the time of the read is inserted into this array
uint64_t pending_read_counter[2] = {0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL};
uint64_t read_counter = 0;
// private constructor
client(
std::string_view get,
addr_t endpoint,
int control_line_fd,
int control_line_fd_0, // hpws -> hpcore [ hpws sends bufs to us over this line ]
int control_line_fd_1, // hpcore -> hpws [ we send bufs to hpws over this line ]
uint32_t max_buffer_size,
pid_t child_pid,
int buffer_fd[4],
void *buffer[4]) : endpoint(endpoint),
control_line_fd(control_line_fd),
max_buffer_size(max_buffer_size),
child_pid(child_pid), get(get)
{
control_line_fd[0] = control_line_fd_0;
control_line_fd[1] = control_line_fd_1;
for (int i = 0; i < 4; ++i)
{
this->buffer[i] = buffer[i];
@@ -113,23 +108,19 @@ namespace hpws
client(client &&old) : child_pid(old.child_pid),
max_buffer_size(old.max_buffer_size),
endpoint(old.endpoint),
control_line_fd(old.control_line_fd),
get(old.get)
{
old.moved = true;
for (int i = 0; i <= 1; ++i)
{
this->control_line_fd[i] = old.control_line_fd[i];
buffer_lock[i] = old.buffer_lock[i];
}
for (int i = 0; i < 4; ++i)
{
this->buffer[i] = old.buffer[i];
this->buffer_fd[i] = old.buffer_fd[i];
}
for (int i = 0; i < 2; ++i)
{
buffer_lock[i] = old.buffer_lock[i];
pending_read[i] = old.pending_read[i];
pending_read_counter[i] = old.pending_read_counter[i];
read_counter = old.read_counter;
}
}
~client()
@@ -151,7 +142,8 @@ namespace hpws
close(buffer_fd[i]);
}
close(control_line_fd);
close(control_line_fd[0]);
close(control_line_fd[1]);
}
}
@@ -169,38 +161,15 @@ namespace hpws
{
unsigned char buf[32];
int bytes_read = 0;
read_start:;
// if during writng we got a read message it's queued as a pending read, so process that first
int do_pending_read = -1;
if (pending_read[0] || pending_read[1])
do_pending_read = (pending_read_counter[0] > pending_read_counter[1] ? 1 : 0);
if (do_pending_read > -1)
int bytes_read = recv(control_line_fd[0], buf, sizeof(buf), 0);
if (bytes_read < 1)
{
if (HPWS_DEBUG)
fprintf(stderr, "[HPWS.HPP] pending read from buffer %d\n", do_pending_read);
bytes_read = pending_read[do_pending_read];
uint32_t len = pending_read[do_pending_read];
pending_read[do_pending_read] = 0;
pending_read_counter[do_pending_read] = 0xFFFFFFFFFFFFFFFFULL;
return std::string_view{(const char *)(buffer[do_pending_read]), len};
}
else
{
bytes_read = recv(control_line_fd, buf, sizeof(buf), 0);
if (bytes_read < 1)
{
if (HPWS_DEBUG)
{
perror("recv");
fprintf(stderr, "[HPWS.HPP] bytes received %d\n", bytes_read);
}
return error{1, "[read] control line could not be read"}; // todo clean up somehow?
perror("recv");
fprintf(stderr, "[HPWS.HPP] bytes received %d\n", bytes_read);
}
return error{1, "[read] control line could not be read"}; // todo clean up somehow?
}
switch (buf[0])
@@ -212,8 +181,7 @@ namespace hpws
if (bytes_read != 6)
return error{3, "invalid buffer in 'o' command sent by hpws"};
++read_counter;
// there's a pending buffer for us
uint32_t len = 0;
DECODE_O_SIZE(buf, len);
@@ -233,36 +201,35 @@ namespace hpws
}
return std::string_view{(const char *)(buffer[bufno]), len};
}
case 'a':
{
if (bytes_read != 2)
return error{4, "received an ack longer than 2 bytes"};
int bufno = buf[1] - '0';
if (!(bufno == 0 || bufno == 1))
return error{5, "received an ack with an invalid buffer, expecting 0 or 1"};
// unlock the buffer
buffer_lock[bufno] = 0;
goto read_start;
}
case 'c':
{
return error{1000, "ws closed"};
}
default:
{
fprintf(stderr, "[HPWS.HPP] read unknown control message 1: `%.*s`\n", bytes_read, buf);
return error{2, "unknown control line command was sent by hpws"};
}
}
}
std::optional<error> write(std::string_view to_write)
{
if (HPWS_DEBUG)
fprintf(stderr, "[HPWS.HPP] write called for message: `%.*s`\n",
(int)to_write.size(), to_write.data());
// check if we have any free buffers
if (buffer_lock[0] && buffer_lock[1])
{
if (HPWS_DEBUG)
fprintf(stderr, "[HPWS.HPP] no free buffers while writing, waiting for an ack on control_line_fd[1]=%d\n", control_line_fd[1]);
// no free buffers, wait for a ack
unsigned char buf[32];
int bytes_read = 0;
write_start:;
bytes_read = recv(control_line_fd, buf, sizeof(buf), 0);
bytes_read = recv(control_line_fd[1], buf, sizeof(buf), 0);
if (bytes_read < 1)
{
perror("recv");
@@ -271,21 +238,6 @@ namespace hpws
switch (buf[0])
{
case 'o':
{
if (bytes_read != 6)
return error{3, "invalid buffer in 'o' command sent by hpws"};
++read_counter;
uint32_t len = 0;
DECODE_O_SIZE(buf, len);
int bufno = buf[1] - '0';
if (bufno != 0 && bufno != 1)
return error{3, "invalid buffer in 'o' command sent by hpws"};
pending_read[bufno] = len;
pending_read_counter[bufno] = read_counter;
goto write_start;
}
case 'a':
{
if (bytes_read != 2)
@@ -319,7 +271,11 @@ namespace hpws
char buf[6] = {'o', (char)('0' + (bufno - 2)), 0, 0, 0, 0};
ENCODE_O_SIZE(buf, len);
if (::write(control_line_fd, buf, 6) != 6)
if (HPWS_DEBUG)
fprintf(stderr, "[HPWS.HPP] writing 'o' to control_line_fd[1]=%d\n",
control_line_fd[1]);
if (::write(control_line_fd[1], buf, 6) != 6)
return error{6, "could not write o message to control line"};
return std::nullopt;
@@ -330,7 +286,7 @@ namespace hpws
char msg[2] = {'a', '0'};
if (from_read.data() == buffer[1])
msg[1]++;
if (send(control_line_fd, msg, 2, 0) < 2)
if (send(control_line_fd[0], msg, 2, 0) < 2)
return error{10, "could not send ack down control line"};
return std::nullopt;
}
@@ -354,14 +310,17 @@ namespace hpws
int error_code = -1;
const char *error_msg = NULL;
int fd[2] = {-1, -1};
int fd[4] = {-1, -1, -1, -1}; // 0,1 are hpws->hpcore, 2,3 are hpcore->hpws
int pid = -1;
int count_args = 12 + argv.size();
int count_args = 14 + argv.size();
char const **argv_pass = NULL;
if (socketpair(AF_UNIX, SOCK_SEQPACKET, 0, fd))
HPWS_CONNECT_ERROR(100, "could not create unix domain socket pair");
if (socketpair(AF_UNIX, SOCK_SEQPACKET, 0, fd + 2))
HPWS_CONNECT_ERROR(101, "could not create unix domain socket pair");
// construct the arguments
char shm_size[32];
@@ -372,6 +331,12 @@ namespace hpws
if (snprintf(port_str, 6, "%d", port) <= 0)
HPWS_CONNECT_ERROR(91, "couldn't write port to string");
char cfd1[10];
char cfd2[10];
snprintf(cfd1, 10, "%d", fd[1]);
snprintf(cfd2, 10, "%d", fd[3]);
argv_pass =
reinterpret_cast<char const **>(alloca(sizeof(char *) * count_args));
{
@@ -385,7 +350,9 @@ namespace hpws
argv_pass[upto++] = "--port";
argv_pass[upto++] = port_str;
argv_pass[upto++] = "--cntlfd";
argv_pass[upto++] = "3";
argv_pass[upto++] = cfd1;
argv_pass[upto++] = "--cntlfd2";
argv_pass[upto++] = cfd2;
argv_pass[upto++] = "--get";
argv_pass[upto++] = get.data();
for (std::string_view &arg : argv)
@@ -401,22 +368,27 @@ namespace hpws
// --- PARENT
close(fd[1]);
close(fd[3]);
int child_fd = fd[0];
int child_fd[2] = {fd[0], fd[2]};
int flags = fcntl(child_fd, F_GETFD, NULL);
if (flags < 0)
int flags[2] = {
fcntl(child_fd[0], F_GETFD, NULL),
fcntl(child_fd[1], F_GETFD, NULL)};
if (flags[0] < 0 || flags[1] < 0)
HPWS_CONNECT_ERROR(101, "could not get flags from unix domain socket");
flags |= FD_CLOEXEC;
if (fcntl(child_fd, F_SETFD, flags))
flags[0] |= FD_CLOEXEC;
flags[1] |= FD_CLOEXEC;
if (fcntl(child_fd[0], F_SETFD, flags[0]) || fcntl(child_fd[1], F_SETFD, flags[1]))
HPWS_CONNECT_ERROR(102, "could notset flags for unix domain socket");
// we will set a timeout and wait for the initial startup message from hpws client mode
struct pollfd pfd;
int ret;
pfd.fd = child_fd;
pfd.fd = child_fd[0]; // we receive setup events on control line 0 (hpws->hpcore)
pfd.events = POLLIN;
ret = poll(&pfd, 1, HPWS_LONG_TIMEOUT); // default= 1500 ms timeout
@@ -430,7 +402,7 @@ namespace hpws
addr_t child_addr;
int bytes_read =
recv(child_fd, (unsigned char *)(&child_addr), sizeof(child_addr), 0);
recv(child_fd[0], (unsigned char *)(&child_addr), sizeof(child_addr), 0);
if (bytes_read < sizeof(child_addr))
HPWS_CONNECT_ERROR(202, "received message on control line was not sizeof(addr_t)");
@@ -449,7 +421,7 @@ namespace hpws
child_msg.msg_controllen = sizeof(cmsgbuf);
int bytes_read =
recvmsg(child_fd, &child_msg, 0);
recvmsg(child_fd[0], &child_msg, 0);
struct cmsghdr *cmsg = CMSG_FIRSTHDR(&child_msg);
if (cmsg == NULL || cmsg->cmsg_type != SCM_RIGHTS)
HPWS_CONNECT_ERROR(203, "non-scm_rights message sent on accept child control line");
@@ -466,24 +438,35 @@ namespace hpws
}
}
if (HPWS_DEBUG)
fprintf(stderr, "[HPWS.HPP] waiting for 'r'\n");
for (int i = 0; i <= 1; ++i)
{
if (HPWS_DEBUG)
fprintf(stderr, "[HPWS.HPP] waiting for 'r' on child_fd[%d]=%d\n", i, child_fd[i]);
// now we wait for a 'r' ready message or for the socket/client to die
ret = poll(&pfd, 1, HPWS_LONG_TIMEOUT); // default= 1500 ms timeout
pfd.fd = child_fd[i];
// now we wait for a 'r' ready message or for the socket/client to die
ret = poll(&pfd, 1, HPWS_LONG_TIMEOUT); // default= 1500 ms timeout
char rbuf[1];
bytes_read = recv(fd[0], rbuf, sizeof(rbuf), 0);
if (bytes_read < 1)
HPWS_CONNECT_ERROR(2, "nil message sent by hpws on startup");
char rbuf[2];
bytes_read = recv(child_fd[i], rbuf, sizeof(rbuf), 0);
if (bytes_read < 1)
HPWS_CONNECT_ERROR(2, "nil message sent by hpws on startup");
if (rbuf[0] != 'r')
HPWS_CONNECT_ERROR(3, "unexpected content in message sent by hpws client mode on startup");
if (rbuf[0] != 'r')
HPWS_CONNECT_ERROR(3, "unexpected content in message sent by hpws client mode on startup");
if (rbuf[1] != '0' + i)
HPWS_CONNECT_ERROR(4, "received wrong r message on control fd");
if (HPWS_DEBUG)
fprintf(stderr, "[HPWS.HPP] received 'r%c' on child_fd[%d]=%d\n", rbuf[1], i, child_fd[i]);
}
return client{
get,
child_addr,
child_fd,
child_fd[0],
child_fd[1],
max_buffer_size,
pid,
buffer_fd,
@@ -497,10 +480,16 @@ namespace hpws
fork_child_init();
close(fd[0]);
close(fd[2]);
// dup fd[1] into fd 3
dup2(fd[1], 3);
close(fd[1]);
/*if (dup2(fd[1], 3) == -1)
perror("dup2 fd[1]");
if (dup2(fd[3], 4) == -1)
perror("dup2 fd[3]");
*/
// close(fd[1]);
// close(fd[3]);
// we're assuming all fds above 3 will have close_exec flag
execv(bin_path.data(), (char *const *)argv_pass);
@@ -521,10 +510,9 @@ namespace hpws
int status;
waitpid(pid, &status, 0 /* should we use WNOHANG? */);
}
if (fd[0] > 0)
close(fd[0]);
if (fd[1] > 0)
close(fd[1]);
for (int i = 0; i < 4; ++i)
if (fd[i] > 0)
close(fd[i]);
return error{error_code, std::string{error_msg}};
}
@@ -578,11 +566,11 @@ namespace hpws
return error{code, msg}; \
}
int child_fd = -1;
int child_fd[2] = {-1, -1};
{
struct msghdr child_msg = {0};
memset(&child_msg, 0, sizeof(child_msg));
char cmsgbuf[CMSG_SPACE(sizeof(int))];
char cmsgbuf[CMSG_SPACE(sizeof(int) * 2)];
child_msg.msg_control = cmsgbuf;
child_msg.msg_controllen = sizeof(cmsgbuf);
@@ -608,15 +596,18 @@ namespace hpws
if (cmsg == NULL || cmsg->cmsg_type != SCM_RIGHTS)
HPWS_ACCEPT_ERROR(200, "non-scm_rights message sent on master control line");
memcpy(&child_fd, CMSG_DATA(cmsg), sizeof(child_fd));
if (child_fd < 0)
HPWS_ACCEPT_ERROR(201, "scm_rights passed fd was negative");
if (child_fd[0] < 0 || child_fd[1] < 0)
HPWS_ACCEPT_ERROR(201, "scm_rights passed fd/s were negative");
if (HPWS_DEBUG)
fprintf(stderr, "[HPWS.HPP] On accept received SCM: child_fd[0] = %d, child_fd[1] = %d\n",
child_fd[0], child_fd[1]);
}
// read info from child control line with a timeout
struct pollfd pfd;
int ret;
pfd.fd = child_fd;
pfd.fd = child_fd[0]; // expect all setup messages on the hpws->hpcore controlfd (0)
pfd.events = POLLIN;
ret = poll(&pfd, 1, HPWS_SMALL_TIMEOUT); // 1 ms timeout
@@ -627,13 +618,13 @@ namespace hpws
// first thing we'll receive is the pid of the client
// must not use pid_t here since we transfer across IPC channel as a uint32.
uint32_t pid = 0;
if (recv(child_fd, (unsigned char *)(&pid), sizeof(pid), 0) < sizeof(pid))
if (recv(child_fd[0], (unsigned char *)(&pid), sizeof(pid), 0) < sizeof(pid))
HPWS_ACCEPT_ERROR(212, "did not receive expected 4 byte pid of child process on accept");
// second thing we'll receive is IP address structure of the client
addr_t buf;
int bytes_read =
recv(child_fd, (unsigned char *)(&buf), sizeof(buf), 0);
recv(child_fd[0], (unsigned char *)(&buf), sizeof(buf), 0);
if (bytes_read < sizeof(buf))
return error{202, "received message on master control line was not sizeof(sockaddr_in6)"};
@@ -649,7 +640,7 @@ namespace hpws
child_msg.msg_controllen = sizeof(cmsgbuf);
int bytes_read =
recvmsg(child_fd, &child_msg, 0);
recvmsg(child_fd[0], &child_msg, 0);
struct cmsghdr *cmsg = CMSG_FIRSTHDR(&child_msg);
if (cmsg == NULL || cmsg->cmsg_type != SCM_RIGHTS)
return error{203, "non-scm_rights message sent on accept child control line"};
@@ -666,30 +657,40 @@ namespace hpws
}
}
{
if (HPWS_DEBUG)
fprintf(stderr, "[HPWS.HPP] waiting for 'r' on accept\n");
struct pollfd pfd;
int ret;
pfd.fd = child_fd;
pfd.events = POLLIN;
// now we wait for a 'r' ready message or for the socket/client to die
ret = poll(&pfd, 1, HPWS_LONG_TIMEOUT); // default= 1500 ms timeout
for (int i = 0; i <= 1; ++i)
{
if (HPWS_DEBUG)
fprintf(stderr, "[HPWS.HPP] waiting for 'r' on child_fd[%d]=%d accept\n", i, child_fd[i]);
pfd.fd = child_fd[i];
pfd.events = POLLIN;
// now we wait for a 'r' ready message or for the socket/client to die
ret = poll(&pfd, 1, HPWS_LONG_TIMEOUT); // default= 1500 ms timeout
char rbuf[1];
bytes_read = recv(child_fd, rbuf, sizeof(rbuf), 0);
if (bytes_read < 1)
HPWS_ACCEPT_ERROR(2, "nil message sent by hpws on startup on accept");
char rbuf[2];
bytes_read = recv(child_fd[i], rbuf, sizeof(rbuf), 0);
if (bytes_read < 1)
HPWS_ACCEPT_ERROR(2, "nil message sent by hpws on startup on accept");
if (rbuf[0] != 'r')
HPWS_ACCEPT_ERROR(3, "unexpected content in message sent by hpws client mode on startup");
if (rbuf[0] != 'r')
HPWS_ACCEPT_ERROR(3, "unexpected content in message sent by hpws client mode on startup");
if (rbuf[1] != '0' + i)
HPWS_ACCEPT_ERROR(4, "received wrong r message on control line fd");
if (HPWS_DEBUG)
fprintf(stderr, "[HPWS.HPP] 'r%c' received on child_fd[%d]=%d\n", rbuf[1], i, child_fd[i]);
}
}
// RH TODO: accept needs a proper child cleanup on failure
return client{
"",
buf,
child_fd,
child_fd[0],
child_fd[1],
max_buffer_size_,
(pid_t)pid,
buffer_fd,

Binary file not shown.