diff --git a/src/hpsh/hpsh.cpp b/src/hpsh/hpsh.cpp index d3a9ea96..bad45c07 100644 --- a/src/hpsh/hpsh.cpp +++ b/src/hpsh/hpsh.cpp @@ -2,8 +2,8 @@ namespace hpsh { - constexpr const char *HPSH_CTR_SH = "sh"; - constexpr const char *HPSH_CTR_TERMINATE = "terminate"; + constexpr uint8_t HPSH_CTRL_TERMINATE = 0; + constexpr uint8_t HPSH_CTRL_SH = 1; constexpr uint32_t POLL_TIMEOUT = 1000; constexpr uint32_t READ_BUFFER_SIZE = 128 * 1024; @@ -94,8 +94,7 @@ namespace hpsh close(ctx.control_fds[0]); for (const auto &command : ctx.commands) { - close(command.child_fds[0]); - close(command.child_fds[1]); + close(command.out_fd); } if (ctx.hpsh_pid > 0) @@ -146,7 +145,7 @@ namespace hpsh int send_terminate_message() { - return (write(ctx.control_fds[1], HPSH_CTR_TERMINATE, 10) < 0) ? -1 : 0; + return (write(ctx.control_fds[1], &HPSH_CTRL_TERMINATE, 1) < 0) ? -1 : 0; } void remove_user_commands(std::string_view user_pubkey) @@ -159,7 +158,7 @@ namespace hpsh if (itr->user_pubkey == user_pubkey) { // Close the file descriptor and remove the command from context. - close(itr->child_fds[1]); + close(itr->out_fd); itr = ctx.commands.erase(itr); } else @@ -180,69 +179,49 @@ namespace hpsh return -2; } + std::string buffer; + buffer.resize(message.size() + 1); + buffer[0] = HPSH_CTRL_SH; + memcpy(buffer.data() + 1, message.data(), message.size()); + // Send the hpsh request header. - if (write(ctx.control_fds[1], HPSH_CTR_SH, 3) < 0) + if (write(ctx.control_fds[1], buffer.data(), message.size() + 1) < 0) { LOG_ERROR << errno << ": Error writing header message to control fd."; return -1; } - // Create a socket pair to communicate for the hpsh request. - int child_fds[2]; - if (socketpair(AF_UNIX, SOCK_SEQPACKET, 0, child_fds) == -1) + // Read the control message which will contain the socket file descriptor. + struct msghdr child_msg = {0}; + memset(&child_msg, 0, sizeof(child_msg)); + char cmsgbuf[CMSG_SPACE(sizeof(int))]; + child_msg.msg_control = cmsgbuf; + child_msg.msg_controllen = sizeof(cmsgbuf); + + recvmsg(ctx.control_fds[1], &child_msg, 0); + + struct cmsghdr *cmsg = CMSG_FIRSTHDR(&child_msg); + + // Skip if the message does not has file descriptor scm rights. + if (cmsg == NULL || cmsg->cmsg_type != SCM_RIGHTS) { - LOG_ERROR << errno << ": Error initializing socket pair."; + LOG_ERROR << "Message sent on control line from hpsh has non-scm_rights."; return -1; } - // Prepare and send the child socket file descriptor with scm rights. - struct msghdr msg = {0}; - struct cmsghdr *cmsg; - char iobuf[1]; - struct iovec io = { - .iov_base = iobuf, - .iov_len = sizeof(iobuf)}; - union - { /* Ancillary data buffer, wrapped in a union - in order to ensure it is suitably aligned */ - char buf[CMSG_SPACE(sizeof(int))]; - struct cmsghdr align; - } u; + int out_fd = -1; + memcpy(&out_fd, CMSG_DATA(cmsg), sizeof(out_fd)); - msg.msg_iov = &io; - msg.msg_iovlen = 1; - msg.msg_control = u.buf; - msg.msg_controllen = sizeof(u.buf); - cmsg = CMSG_FIRSTHDR(&msg); - cmsg->cmsg_level = SOL_SOCKET; - cmsg->cmsg_type = SCM_RIGHTS; - cmsg->cmsg_len = CMSG_LEN(sizeof(int)); - memcpy(CMSG_DATA(cmsg), child_fds, sizeof(int)); - - if (sendmsg(ctx.control_fds[1], &msg, 0) < 0) + if (out_fd <= 0) { - LOG_ERROR << errno << ": Error writing to control fd."; - close(child_fds[0]); - close(child_fds[1]); + LOG_ERROR << "Invalid file descriptor receives on control line from hpsh"; return -1; } - // Write the request message to the child socket. - if (write(child_fds[1], message.data(), message.size()) < 0) - { - LOG_ERROR << errno << ": Error writing to child fd."; - close(child_fds[0]); - close(child_fds[1]); - return -1; - } - - // Close the child end of the socket. - close(child_fds[0]); - // Add the command to the context. { std::scoped_lock lock(ctx.command_mutex); - ctx.commands.push_back(command_context{std::string(id), std::string(user_pubkey), {child_fds[0], child_fds[1]}}); + ctx.commands.push_back(command_context{std::string(id), std::string(user_pubkey), out_fd}); } return 0; @@ -266,16 +245,19 @@ namespace hpsh break; struct pollfd pfd; - pfd.fd = itr->child_fds[1]; + pfd.fd = itr->out_fd; pfd.events = POLLIN; + bool remove = false; + // If child fd has data to read handle them. - if (poll(&pfd, 1, POLL_TIMEOUT) == -1) + const int poll_res = poll(&pfd, 1, POLL_TIMEOUT); + if (poll_res == -1) { LOG_ERROR << errno << ": Error in poll"; - continue; + remove = true; } - else if (pfd.revents & POLLIN) + else if (poll_res > 0 && (pfd.revents & POLLIN)) { // Read the response and send to the user. std::string response; @@ -303,11 +285,26 @@ namespace hpsh } else if (res == -1) { - LOG_ERROR << errno << ": Error reading from fd"; + LOG_ERROR << errno << ": Error reading from fd."; + remove = true; + } + else + { + LOG_DEBUG << "Hpsh has closed the connection."; + remove = true; } } - itr++; + if (remove) + { + // Close the file descriptor and remove the command from context. + close(itr->out_fd); + itr = ctx.commands.erase(itr); + } + else + { + itr++; + } } } util::sleep(100); diff --git a/src/hpsh/hpsh.hpp b/src/hpsh/hpsh.hpp index 6c859897..ef9ad675 100644 --- a/src/hpsh/hpsh.hpp +++ b/src/hpsh/hpsh.hpp @@ -12,7 +12,7 @@ namespace hpsh { std::string id; std::string user_pubkey; - int child_fds[2]; + int out_fd; }; struct hpsh_context diff --git a/test/bin/hpsh b/test/bin/hpsh index ce252470..1c76d10b 100755 Binary files a/test/bin/hpsh and b/test/bin/hpsh differ