Document testing method.
[virt-hostinfo.git] / hostinfod / commands.c
index 708b4a8..d2a1d2d 100644 (file)
 
 #include <stdio.h>
 #include <stdlib.h>
+#include <stdarg.h>
 #include <assert.h>
 #include <ctype.h>
 #include <unistd.h>
 #include <time.h>
+#include <math.h>
+
+#include <apr_general.h>
+#include <apr_pools.h>
+#include <apr_hash.h>
+#include <apr_strings.h>
 
 #include "hostinfod.h"
 
 apr_hash_t *commands = NULL;
 
 static const char *string_of_arg_type (enum arg_type);
-static void init_commands_hash (void);
 static int not_printable (const char *str);
+static int contains_crlf (const char *str);
+static char *parse_c_string (struct guest_description *hval, const char *str, size_t *posn);
+static int parse_long (struct guest_description *hval, const char *str, size_t *posn, long *ret);
 
 void
 execute_command (const struct timespec *now,
@@ -50,9 +59,12 @@ execute_command (const struct timespec *now,
   size_t len, i, j;
   int neg;
   char *cmd;
-  apr_array_header_t args;
+  apr_array_header_t *args;
   struct arg arg;
   command_fn fn;
+  int enabled;
+  double interval;
+  struct timespec *last;
 
   debug ("%s: %s", hval->name, command);
 
@@ -76,8 +88,8 @@ execute_command (const struct timespec *now,
    */
   len = strlen (command);
   if (len < 1 || len > 4094 || not_printable (command)) {
-    warning ("%s: command too short, too long or contained unprintable chars",
-            hval->name);
+    warning ("%s: command too short, too long or contained unprintable chars (len = %zu)",
+            hval->name, len);
     send_error (hval, 400);
     return;
   }
@@ -118,7 +130,11 @@ execute_command (const struct timespec *now,
 
     case '"':                  /* string literal */
       arg.type = arg_type_string;
-      XXX;
+      arg.u.str = parse_c_string (hval, command, &i);
+      if (!arg.u.str) {
+       send_error (hval, 400);
+       return;
+      }
       break;
 
     case '-':                  /* integer literal */
@@ -134,7 +150,10 @@ execute_command (const struct timespec *now,
     case '8':
     case '9':
       arg.type = arg_type_int;
-      XXX ();
+      if (parse_long (hval, command, &i, &arg.u.i) == -1) {
+       send_error (hval, 400);
+       return;
+      }
       break;
 
     case 't':                  /* boolean */
@@ -147,20 +166,20 @@ execute_command (const struct timespec *now,
        i += 4;
       } else if (strncasecmp (&command[i], "false", 5) == 0) {
        arg.u.i = 0;
-       i += 4;
+       i += 5;
       } else
        goto unknown_arg;
       break;
 
     default:
     unknown_arg:
-      warning ("%s: unknown or malformed argument starting at position %d ('%c')",
+      warning ("%s: unknown or malformed argument starting at position %zu ('%c')",
               hval->name, i, command[i]);
       send_error (hval, 400);
       return;
     } /* switch */
 
-    APR_ARRAY_PUSH (args, struct arg) = *arg;
+    APR_ARRAY_PUSH (args, struct arg) = arg;
   }
 
   if (verbose) {
@@ -172,11 +191,11 @@ execute_command (const struct timespec *now,
       switch (arg.type) {
       case arg_type_bool:
       case arg_type_int:
-       debug ("%s: arg %d : %s = %ld",
+       debug ("%s: arg %zu : %s = %ld",
               hval->name, i, string_of_arg_type (arg.type), arg.u.i);
        break;
       case arg_type_string:
-       debug ("%s: arg %d : %s = %s",
+       debug ("%s: arg %zu : %s = %s",
               hval->name, i, string_of_arg_type (arg.type), arg.u.str);
        break;
       }
@@ -193,8 +212,42 @@ execute_command (const struct timespec *now,
   /* Before dispatching the command, check the command is enabled
    * and guest is not calling it too frequently.
    */
-  error ("XXXXXXX enabled check not implemented XXXXXXX");
-  error ("XXXXXXX frequency check not implemented XXXXXXX");
+  check_guests_file (hval, cmd, &interval, &enabled);
+
+  if (!enabled) {
+    warning ("%s: guest tried disabled command '%s'", hval->name, cmd);
+    send_error (hval, 401);
+    return;
+  }
+
+  last = apr_hash_get (hval->lasttime, cmd, APR_HASH_KEY_STRING);
+  if (last) {
+    struct timespec timediff;
+    double interval_int, interval_frac;
+    struct timespec interval_ts;
+
+    diff_timespec (&timediff, now, last);
+
+    interval_frac = modf (interval, &interval_int);
+    interval_ts.tv_sec = interval_int;
+    interval_ts.tv_nsec = interval_frac * 1000000000;
+
+    debug ("%s: %s: interval %ds %ldns, time since last %ds %ldns",
+          hval->name, cmd,
+          (int) interval_ts.tv_sec, interval_ts.tv_nsec,
+          (int) timediff.tv_sec, timediff.tv_nsec);
+
+    if (interval_ts.tv_sec > timediff.tv_sec ||
+       (interval_ts.tv_sec == timediff.tv_sec &&
+        interval_ts.tv_nsec > timediff.tv_nsec)) {
+      warning ("%s: command '%s' exceeded interval allowed", hval->name, cmd);
+      send_error (hval, 406);
+      return;
+    }
+  }
+
+  last = apr_pmemdup (hval->pool, now, sizeof *now);
+  apr_hash_set (hval->lasttime, cmd, APR_HASH_KEY_STRING, last);
 
   /* Dispatch the command. */
   fn (hval, cmd, args);
@@ -210,22 +263,188 @@ not_printable (const char *str)
 
   while ((c = *(unsigned char *)str)) {
     if (c < 32 || c > 126)
-      return 0;
+      return 1;
+    str++;
+  }
+  return 0;
+}
+
+static int
+contains_crlf (const char *str)
+{
+  int c;
+
+  while ((c = *(unsigned char *)str)) {
+    if (c == '\r' || c == '\n')
+      return 1;
     str++;
   }
-  return 1;
+  return 0;
+}
+
+static int
+hexval (int c)
+{
+  switch (c) {
+  case '0': case '1': case '2': case '3': case '4':
+  case '5': case '6': case '7': case '8': case '9':
+    return c - '0';
+  case 'a': case 'b': case 'c': case 'd': case 'e': case 'f':
+    return c - 'a' + 10;
+  case 'A': case 'B': case 'C': case 'D': case 'E': case 'F':
+    return c - 'A' + 10;
+  default:
+    abort ();
+  }
+}
+
+static char *
+parse_c_string (struct guest_description *hval, const char *str, size_t *posn)
+{
+  apr_array_header_t *r;
+  int c;
+
+  r = apr_array_make (hval->rpool, 0, 1);
+#define APPEND(c) APR_ARRAY_PUSH (r,char) = (c);
+
+  assert (str[*posn] == '"');
+  (*posn)++;
+
+  while (str[*posn] != '"') {
+    if (str[*posn] == '\\') {  /* Start of \escape sequence. */
+      switch (str[*posn+1]) {
+      case '0': case '1': case '2': case '3':
+       if ((str[*posn+2] >= '0' && str[*posn+2] <= '7') &&
+           (str[*posn+3] >= '0' && str[*posn+3] <= '7')) {
+         c = (str[*posn+1] - '0') * 0100;
+         c += (str[*posn+2] - '0') * 010;
+         c += str[*posn+3] - '0';
+         if (c != 0) {
+           APPEND (c);
+           *posn += 4;
+         } else {
+           warning ("%s: \\0 cannot appear in string literal", hval->name);
+           return NULL;
+         }
+         break;
+       } else {
+         warning ("%s: invalid octal sequence in string literal", hval->name);
+         return NULL;
+       }
+      case 'x':
+       if (isxdigit (str[*posn+2]) && isxdigit (str[*posn+3])) {
+         c = hexval (str[*posn+2]) * 0x10;
+         c += hexval (str[*posn+3]);
+         if (c != 0) {
+           APPEND (c);
+           *posn += 4;
+         } else {
+           warning ("%s: \\0 cannot appear in string literal", hval->name);
+           return NULL;
+         }
+         break;
+       } else {
+         warning ("%s: invalid hex sequence in string literal", hval->name);
+         return NULL;
+       }
+      case 'a':
+       APPEND ('\a');
+       *posn += 2;
+       break;
+      case 'b':
+       APPEND ('\b');
+       *posn += 2;
+       break;
+      case 'f':
+       APPEND ('\f');
+       *posn += 2;
+       break;
+      case 'n':
+       APPEND ('\n');
+       *posn += 2;
+       break;
+      case 'r':
+       APPEND ('\r');
+       *posn += 2;
+       break;
+      case 't':
+       APPEND ('\t');
+       *posn += 2;
+       break;
+      case 'v':
+       APPEND ('\v');
+       *posn += 2;
+       break;
+      case '"':
+       APPEND ('"');
+       *posn += 2;
+       break;
+      case '\\':
+       APPEND ('\\');
+       *posn += 2;
+       break;
+
+      default:
+       warning ("%s: unterminated escape sequence in string literal",
+                hval->name);
+       return NULL;
+      }
+    } else if (str[*posn] == '\0') { /* Unterminated string literal. */
+      warning ("%s: unterminated string literal in request", hval->name);
+      return NULL;
+    } else {                   /* Ordinary character. */
+      APPEND (str[*posn]);
+      (*posn)++;
+    }
+  }
+
+  /* Finish off the string and return it. */
+  APPEND ('\0');
+  (*posn)++;                   /* Skips over the final quote. */
+
+  return r->elts;
+}
+
+static int
+parse_long (struct guest_description *hval,
+           const char *str, size_t *posn, long *ret)
+{
+  error ("XXXXXXX parse_long not implemented XXXXXXX");
+  return -1;
 }
 
 /* For single line replies. */
 void
 send_reply (struct guest_description *hval, int code, const char *fs, ...)
 {
-  const char *msg;
+  va_list args;
+  char *msg;
 
   /* All success codes must be 2xx. */
   assert (code >= 200 && code < 300);
 
-  XXX;
+  va_start (args, fs);
+  msg = apr_pvsprintf (hval->rpool, fs, args);
+  va_end (args);
+
+  /* The result string must not contain any CR or LF characters.  If
+   * not it's an internal error in the command, or else the caller has
+   * (somehow) managed to pass a bad string through.
+   */
+  if (contains_crlf (msg)) {
+    error ("%s: send_reply: refusing the send a reply message containing CR/LF characters.  This is a serious internal error in the current command.",
+          hval->name);
+    send_error (hval, 500);
+    return;
+  }
+
+  /* Send reply. */
+  hval->reply =
+    apr_psprintf (hval->rpool, PROTOCOL_VERSION " %03d %s" CRLF,
+                 code, msg);
+  hval->reply_size = strlen (hval->reply);
+  hval->reply_posn = 0;
+  hval->state = guest_state_reply;
 }
 
 void
@@ -261,3 +480,79 @@ send_error (struct guest_description *hval, int code)
    */
   hval->penalty++;
 }
+
+static const char *
+string_of_arg_type (enum arg_type t)
+{
+  switch (t) {
+  case arg_type_string: return "string";
+  case arg_type_int:    return "int";
+  case arg_type_bool:   return "bool";
+  }
+  abort ();
+}
+
+int
+get_args (apr_array_header_t *args, const char *argfs, ...)
+{
+  va_list vargs;
+  int i = 0, ret = 0;
+
+  va_start (vargs, argfs);
+
+  while (*argfs) {
+    struct arg arg;
+
+    if (i >= args->nelts) {
+      ret = -1;
+      goto end;
+    }
+
+    arg = APR_ARRAY_IDX (args, i, struct arg);
+
+    switch (*argfs) {
+    case 's':
+      if (arg.type == arg_type_string)
+       * va_arg (vargs, const char **) = arg.u.str;
+      else {
+       ret = -1;
+       goto end;
+      }
+      break;
+
+    case 'i':
+      if (arg.type == arg_type_int)
+       * va_arg (vargs, int *) = arg.u.i;
+      else {
+       ret = -1;
+       goto end;
+      }
+      break;
+
+    case 'b':
+      if (arg.type == arg_type_bool)
+       * va_arg (vargs, int *) = arg.u.i;
+      else {
+       ret = -1;
+       goto end;
+      }
+      break;
+
+    default:
+      error ("get_args: invalid character '%c'", *argfs);
+      ret = -1;
+      goto end;
+    }
+
+    argfs++;
+    i++;
+  }
+
+  if (i < args->nelts)
+    ret = -1;
+
+ end:
+  va_end (vargs);
+
+  return ret;
+}