Port to pcre2
[miniexpect.git] / miniexpect.c
index fef03fd..77d781c 100644 (file)
 #include <stdio.h>
 #include <stdlib.h>
 #include <stdarg.h>
+#include <string.h>
+#include <ctype.h>
 #include <fcntl.h>
 #include <unistd.h>
+#include <signal.h>
 #include <poll.h>
 #include <errno.h>
 #include <termios.h>
 #include <sys/wait.h>
 #include <sys/time.h>
 
-#include <pcre.h>
+#define PCRE2_CODE_UNIT_WIDTH 8
+#include <pcre2.h>
 
 #include "miniexpect.h"
 
-#define DEBUG 0
+static void debug_buffer (FILE *, const char *);
 
 static mexp_h *
 create_handle (void)
@@ -53,6 +57,8 @@ create_handle (void)
   h->pcre_error = 0;
   h->buffer = NULL;
   h->len = h->alloc = 0;
+  h->next_match = -1;
+  h->debug_fp = NULL;
   h->user1 = h->user2 = h->user3 = NULL;
 
   return h;
@@ -64,6 +70,7 @@ clear_buffer (mexp_h *h)
   free (h->buffer);
   h->buffer = NULL;
   h->alloc = h->len = 0;
+  h->next_match = -1;
 }
 
 int
@@ -80,11 +87,13 @@ mexp_close (mexp_h *h)
       return -1;
   }
 
+  free (h);
+
   return status;
 }
 
 mexp_h *
-mexp_spawnl (const char *file, const char *arg, ...)
+mexp_spawnlf (unsigned flags, const char *file, const char *arg, ...)
 {
   char **argv, **new_argv;
   size_t i;
@@ -102,21 +111,23 @@ mexp_spawnl (const char *file, const char *arg, ...)
     new_argv = realloc (argv, sizeof (char *) * (i+1));
     if (new_argv == NULL) {
       free (argv);
+      va_end (args);
       return NULL;
     }
     argv = new_argv;
     argv[i] = (char *) arg;
   }
 
-  h = mexp_spawnv (file, argv);
+  h = mexp_spawnvf (flags, file, argv);
   free (argv);
+  va_end (args);
   return h;
 }
 
 mexp_h *
-mexp_spawnv (const char *file, char **argv)
+mexp_spawnvf (unsigned flags, const char *file, char **argv)
 {
-  mexp_h *h;
+  mexp_h *h = NULL;
   int fd = -1;
   int err;
   char slave[1024];
@@ -146,9 +157,25 @@ mexp_spawnv (const char *file, char **argv)
     goto error;
 
   if (pid == 0) {               /* Child. */
-    struct termios terminal_settings;
     int slave_fd;
 
+    if (!(flags & MEXP_SPAWN_KEEP_SIGNALS)) {
+      struct sigaction sa;
+      int i;
+
+      /* Remove all signal handlers.  See the justification here:
+       * https://www.redhat.com/archives/libvir-list/2008-August/msg00303.html
+       * We don't mask signal handlers yet, so this isn't completely
+       * race-free, but better than not doing it at all.
+       */
+      memset (&sa, 0, sizeof sa);
+      sa.sa_handler = SIG_DFL;
+      sa.sa_flags = 0;
+      sigemptyset (&sa.sa_mask);
+      for (i = 1; i < NSIG; ++i)
+        sigaction (i, &sa, NULL);
+    }
+
     setsid ();
 
     /* Open the slave side of the pty.  We must do this in the child
@@ -158,10 +185,14 @@ mexp_spawnv (const char *file, char **argv)
     if (slave_fd == -1)
       goto error;
 
-    /* Set raw mode. */
-    tcgetattr (slave_fd, &terminal_settings);
-    cfmakeraw (&terminal_settings);
-    tcsetattr (slave_fd, TCSANOW, &terminal_settings);
+    if (!(flags & MEXP_SPAWN_COOKED_MODE)) {
+      struct termios termios;
+
+      /* Set raw mode. */
+      tcgetattr (slave_fd, &termios);
+      cfmakeraw (&termios);
+      tcsetattr (slave_fd, TCSANOW, &termios);
+    }
 
     /* Set up stdin, stdout, stderr to point to the pty. */
     dup2 (slave_fd, 0);
@@ -174,6 +205,21 @@ mexp_spawnv (const char *file, char **argv)
      */
     close (fd);
 
+    if (!(flags & MEXP_SPAWN_KEEP_FDS)) {
+      int i, max_fd;
+
+      /* Close all other file descriptors.  This ensures that we don't
+       * hold open (eg) pipes from the parent process.
+       */
+      max_fd = sysconf (_SC_OPEN_MAX);
+      if (max_fd == -1)
+        max_fd = 1024;
+      if (max_fd > 65536)
+        max_fd = 65536;      /* bound the amount of work we do here */
+      for (i = 3; i < max_fd; ++i)
+        close (i);
+    }
+
     /* Run the subprocess. */
     execvp (file, argv);
     perror (file);
@@ -192,12 +238,15 @@ mexp_spawnv (const char *file, char **argv)
     close (fd);
   if (pid > 0)
     waitpid (pid, NULL, 0);
+  if (h != NULL)
+    mexp_close (h);
   errno = err;
   return NULL;
 }
 
 enum mexp_status
-mexp_expect (mexp_h *h, const mexp_regexp *regexps, int *ovector, int ovecsize)
+mexp_expect (mexp_h *h, const mexp_regexp *regexps,
+             pcre2_match_data *match_data)
 {
   time_t start_t, now_t;
   int timeout;
@@ -207,8 +256,19 @@ mexp_expect (mexp_h *h, const mexp_regexp *regexps, int *ovector, int ovecsize)
 
   time (&start_t);
 
-  /* Clear the read buffer. */
-  clear_buffer (h);
+  if (h->next_match == -1) {
+    /* Fully clear the buffer, then read. */
+    clear_buffer (h);
+  } else {
+    /* See the comment in the manual about h->next_match.  We have
+     * some data remaining in the buffer, so begin by matching that.
+     */
+    memmove (&h->buffer[0], &h->buffer[h->next_match], h->len - h->next_match);
+    h->len -= h->next_match;
+    h->buffer[h->len] = '\0';
+    h->next_match = -1;
+    goto try_match;
+  }
 
   for (;;) {
     /* If we've got a timeout then work out how many seconds are left.
@@ -228,9 +288,8 @@ mexp_expect (mexp_h *h, const mexp_regexp *regexps, int *ovector, int ovecsize)
     pfds[0].events = POLLIN;
     pfds[0].revents = 0;
     r = poll (pfds, 1, timeout);
-#if DEBUG
-    fprintf (stderr, "DEBUG: poll returned %d\n", r);
-#endif
+    if (h->debug_fp)
+      fprintf (h->debug_fp, "DEBUG: poll returned %d\n", r);
     if (r == -1)
       return MEXP_ERROR;
 
@@ -250,9 +309,8 @@ mexp_expect (mexp_h *h, const mexp_regexp *regexps, int *ovector, int ovecsize)
       h->alloc += h->read_size;
     }
     rs = read (h->fd, h->buffer + h->len, h->read_size);
-#if DEBUG
-    fprintf (stderr, "DEBUG: read returned %zd\n", rs);
-#endif
+    if (h->debug_fp)
+      fprintf (h->debug_fp, "DEBUG: read returned %zd\n", rs);
     if (rs == -1) {
       /* Annoyingly on Linux (I'm fairly sure this is a bug) if the
        * writer closes the connection, the entire pty is destroyed,
@@ -268,11 +326,14 @@ mexp_expect (mexp_h *h, const mexp_regexp *regexps, int *ovector, int ovecsize)
     /* We read something. */
     h->len += rs;
     h->buffer[h->len] = '\0';
-#if DEBUG
-    fprintf (stderr, "DEBUG: read %zd bytes from pty\n", rs);
-    fprintf (stderr, "DEBUG: buffer content: %s\n", h->buffer);
-#endif
+    if (h->debug_fp) {
+      fprintf (h->debug_fp, "DEBUG: read %zd bytes from pty\n", rs);
+      fprintf (h->debug_fp, "DEBUG: buffer content: ");
+      debug_buffer (h->debug_fp, h->buffer);
+      fprintf (h->debug_fp, "\n");
+    }
 
+  try_match:
     /* See if there is a full or partial match against any regexp. */
     if (regexps) {
       size_t i;
@@ -281,25 +342,36 @@ mexp_expect (mexp_h *h, const mexp_regexp *regexps, int *ovector, int ovecsize)
       assert (h->buffer != NULL);
 
       for (i = 0; regexps[i].r > 0; ++i) {
-        int options = regexps[i].options | PCRE_PARTIAL_SOFT;
+        const int options = regexps[i].options | PCRE2_PARTIAL_SOFT;
 
-        r = pcre_exec (regexps[i].re, regexps[i].extra,
-                       h->buffer, (int)h->len, 0,
-                       options,
-                       ovector, ovecsize);
+        r = pcre2_match (regexps[i].re,
+                         (PCRE2_SPTR) h->buffer, (int)h->len, 0,
+                         options, match_data, NULL);
         h->pcre_error = r;
 
         if (r >= 0) {
           /* A full match. */
+          const PCRE2_SIZE *ovector = NULL;
+
+          if (match_data)
+            ovector = pcre2_get_ovector_pointer (match_data);
+
+          if (ovector != NULL && ovector[1] >= 0)
+            h->next_match = ovector[1];
+          else
+            h->next_match = -1;
+          if (h->debug_fp)
+            fprintf (h->debug_fp, "DEBUG: next_match at buffer offset %zu\n",
+                     h->next_match);
           return regexps[i].r;
         }
 
-        else if (r == PCRE_ERROR_NOMATCH) {
+        else if (r == PCRE2_ERROR_NOMATCH) {
           /* No match at all. */
           /* (nothing here) */
         }
 
-        else if (r == PCRE_ERROR_PARTIAL) {
+        else if (r == PCRE2_ERROR_PARTIAL) {
           /* Partial match.  Keep the buffer and keep reading. */
           can_clear_buffer = 0;
         }
@@ -320,26 +392,32 @@ mexp_expect (mexp_h *h, const mexp_regexp *regexps, int *ovector, int ovecsize)
   }
 }
 
-int
-mexp_printf (mexp_h *h, const char *fs, ...)
+static int mexp_vprintf (mexp_h *h, int password, const char *fs, va_list args)
+  __attribute__((format(printf,3,0)));
+
+static int
+mexp_vprintf (mexp_h *h, int password, const char *fs, va_list args)
 {
-  va_list args;
   char *msg;
   int len;
   size_t n;
   ssize_t r;
   char *p;
 
-  va_start (args, fs);
   len = vasprintf (&msg, fs, args);
-  va_end (args);
 
   if (len < 0)
     return -1;
 
-#if DEBUG
-  fprintf (stderr, "DEBUG: writing: %s\n", msg);
-#endif
+  if (h->debug_fp) {
+    if (!password) {
+      fprintf (h->debug_fp, "DEBUG: writing: ");
+      debug_buffer (h->debug_fp, msg);
+      fprintf (h->debug_fp, "\n");
+    }
+    else
+      fprintf (h->debug_fp, "DEBUG: writing the password\n");
+  }
 
   n = len;
   p = msg;
@@ -356,3 +434,58 @@ mexp_printf (mexp_h *h, const char *fs, ...)
   free (msg);
   return len;
 }
+
+int
+mexp_printf (mexp_h *h, const char *fs, ...)
+{
+  int r;
+  va_list args;
+
+  va_start (args, fs);
+  r = mexp_vprintf (h, 0, fs, args);
+  va_end (args);
+  return r;
+}
+
+int
+mexp_printf_password (mexp_h *h, const char *fs, ...)
+{
+  int r;
+  va_list args;
+
+  va_start (args, fs);
+  r = mexp_vprintf (h, 1, fs, args);
+  va_end (args);
+  return r;
+}
+
+int
+mexp_send_interrupt (mexp_h *h)
+{
+  return write (h->fd, "\003", 1);
+}
+
+/* Print escaped buffer to fp. */
+static void
+debug_buffer (FILE *fp, const char *buf)
+{
+  while (*buf) {
+    if (isprint (*buf))
+      fputc (*buf, fp);
+    else {
+      switch (*buf) {
+      case '\0': fputs ("\\0", fp); break;
+      case '\a': fputs ("\\a", fp); break;
+      case '\b': fputs ("\\b", fp); break;
+      case '\f': fputs ("\\f", fp); break;
+      case '\n': fputs ("\\n", fp); break;
+      case '\r': fputs ("\\r", fp); break;
+      case '\t': fputs ("\\t", fp); break;
+      case '\v': fputs ("\\v", fp); break;
+      default:
+        fprintf (fp, "\\x%x", (unsigned char) *buf);
+      }
+    }
+    buf++;
+  }
+}