Hostinfo day 4: Implement command processing code.
[virt-hostinfo.git] / hostinfod / commands.c
1 /* virt-hostinfo
2  * Copyright (C) 2009 Red Hat Inc.
3  *
4  * This program is free software; you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation; either version 2 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program; if not, write to the Free Software
16  * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
17  */
18
19 /* This code parses and executes the commands sent by guests.  It
20  * is therefore particularly security sensitive.  The protocol is
21  * documented in hostinfo-protocol(5).
22  */
23
24 #ifdef HAVE_CONFIG_H
25 #include <config.h>
26 #endif
27
28 #include <stdio.h>
29 #include <stdlib.h>
30 #include <stdarg.h>
31 #include <assert.h>
32 #include <ctype.h>
33 #include <unistd.h>
34 #include <time.h>
35
36 #include <apr_general.h>
37 #include <apr_pools.h>
38 #include <apr_hash.h>
39 #include <apr_strings.h>
40
41 #include "hostinfod.h"
42
43 #define PROTOCOL_VERSION "1.0"
44 #define CRLF "\r\n"
45
46 apr_hash_t *commands = NULL;
47
48 static const char *string_of_arg_type (enum arg_type);
49 static int not_printable (const char *str);
50 static int contains_crlf (const char *str);
51 static char *parse_c_string (struct guest_description *hval, const char *str, size_t *posn);
52 static int parse_long (struct guest_description *hval, const char *str, size_t *posn, long *ret);
53
54 void
55 execute_command (const struct timespec *now,
56                  struct guest_description *hval, const char *command)
57 {
58   size_t len, i, j;
59   int neg;
60   char *cmd;
61   apr_array_header_t *args;
62   struct arg arg;
63   command_fn fn;
64
65   debug ("%s: %s", hval->name, command);
66
67   /* Create a new pool for allocation during the lifetime of the
68    * request/response.  **NB** the hval->reply field is allocated
69    * from this pool, which is why it gets nulled below.
70    */
71   if (hval->rpool)
72     apr_pool_destroy (hval->rpool);
73   hval->rpool = NULL;
74   hval->reply = NULL;
75   hval->reply_size = hval->reply_posn = 0;
76   apr_pool_create (&hval->rpool, hval->pool);
77
78   /* Split up the command.  Commands have a very narrowly-defined
79    * format, and we reject any malformed commands with a 400 error.
80    *
81    * NB. A lot of the code below assumes 7 bit, printable ASCII,
82    * and this is only safe because of the call to 'not_printable'
83    * here.
84    */
85   len = strlen (command);
86   if (len < 1 || len > 4094 || not_printable (command)) {
87     warning ("%s: command too short, too long or contained unprintable chars (len = %d)",
88              hval->name, len);
89     send_error (hval, 400);
90     return;
91   }
92
93   /* Command is alphanumeric string, non-zero length. */
94   i = strspn (command,
95               "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
96               "abcdefghijklmnopqrstuvwxyz"
97               "0123456789");
98   if (i < 1) {
99     warning ("%s: no command part in request", hval->name);
100     send_error (hval, 400);
101     return;
102   }
103
104   cmd = apr_pstrmemdup (hval->rpool, command, i);
105   for (j = 0; j < i; ++j)
106     cmd[j] = tolower (cmd[j]);
107
108   args = apr_array_make (hval->rpool, 0, sizeof (struct arg));
109
110   while (command[i]) {  /* command[i] is the space before the next arg */
111     if (command[i] != ' ') {
112       warning ("%s: there must be a single space between command and each argument",
113                hval->name);
114       send_error (hval, 400);
115       return;
116     }
117
118     neg = 0;
119
120     i++;
121     switch (command[i]) {
122     case '\0':
123       warning ("%s: trailing space after command", hval->name);
124       send_error (hval, 400);
125       return;
126
127     case '"':                   /* string literal */
128       arg.type = arg_type_string;
129       arg.u.str = parse_c_string (hval, command, &i);
130       if (!arg.u.str) {
131         send_error (hval, 400);
132         return;
133       }
134       break;
135
136     case '-':                   /* integer literal */
137       neg = 1;
138     case '0':
139     case '1':
140     case '2':
141     case '3':
142     case '4':
143     case '5':
144     case '6':
145     case '7':
146     case '8':
147     case '9':
148       arg.type = arg_type_int;
149       if (parse_long (hval, command, &i, &arg.u.i) == -1) {
150         send_error (hval, 400);
151         return;
152       }
153       break;
154
155     case 't':                   /* boolean */
156     case 'f':
157     case 'T':
158     case 'F':
159       arg.type = arg_type_bool;
160       if (strncasecmp (&command[i], "true", 4) == 0) {
161         arg.u.i = 1;
162         i += 4;
163       } else if (strncasecmp (&command[i], "false", 5) == 0) {
164         arg.u.i = 0;
165         i += 5;
166       } else
167         goto unknown_arg;
168       break;
169
170     default:
171     unknown_arg:
172       warning ("%s: unknown or malformed argument starting at position %d ('%c')",
173                hval->name, i, command[i]);
174       send_error (hval, 400);
175       return;
176     } /* switch */
177
178     APR_ARRAY_PUSH (args, struct arg) = arg;
179   }
180
181   if (verbose) {
182     /* Debug output what the guest sent / what we decoded. */
183     debug ("%s: command '%s' with %d args", hval->name, cmd, args->nelts);
184     for (i = 0; i < args->nelts; ++i) {
185       struct arg arg = APR_ARRAY_IDX (args, i, struct arg);
186
187       switch (arg.type) {
188       case arg_type_bool:
189       case arg_type_int:
190         debug ("%s: arg %d : %s = %ld",
191                hval->name, i, string_of_arg_type (arg.type), arg.u.i);
192         break;
193       case arg_type_string:
194         debug ("%s: arg %d : %s = %s",
195                hval->name, i, string_of_arg_type (arg.type), arg.u.str);
196         break;
197       }
198     }
199   }
200
201   /* Know about this command? */
202   fn = apr_hash_get (commands, cmd, APR_HASH_KEY_STRING);
203   if (!fn) {
204     send_error (hval, 404);
205     return;
206   }
207
208   /* Before dispatching the command, check the command is enabled
209    * and guest is not calling it too frequently.
210    */
211   error ("XXXXXXX enabled check not implemented XXXXXXX");
212   error ("XXXXXXX frequency check not implemented XXXXXXX");
213
214   /* Dispatch the command. */
215   fn (hval, cmd, args);
216 }
217
218 /* All commands must consist only of printable 7 bit ASCII.
219  * NB. Don't use isprint(3).
220  */
221 static int
222 not_printable (const char *str)
223 {
224   int c;
225
226   while ((c = *(unsigned char *)str)) {
227     if (c < 32 || c > 126)
228       return 1;
229     str++;
230   }
231   return 0;
232 }
233
234 static int
235 contains_crlf (const char *str)
236 {
237   int c;
238
239   while ((c = *(unsigned char *)str)) {
240     if (c == '\r' || c == '\n')
241       return 1;
242     str++;
243   }
244   return 0;
245 }
246
247 static int
248 hexval (int c)
249 {
250   switch (c) {
251   case '0': case '1': case '2': case '3': case '4':
252   case '5': case '6': case '7': case '8': case '9':
253     return c - '0';
254   case 'a': case 'b': case 'c': case 'd': case 'e': case 'f':
255     return c - 'a' + 10;
256   case 'A': case 'B': case 'C': case 'D': case 'E': case 'F':
257     return c - 'A' + 10;
258   default:
259     abort ();
260   }
261 }
262
263 static char *
264 parse_c_string (struct guest_description *hval, const char *str, size_t *posn)
265 {
266   apr_array_header_t *r;
267   int c;
268
269   r = apr_array_make (hval->rpool, 0, 1);
270 #define APPEND(c) APR_ARRAY_PUSH (r,char) = (c);
271
272   assert (str[*posn] == '"');
273   (*posn)++;
274
275   while (str[*posn] != '"') {
276     if (str[*posn] == '\\') {   /* Start of \escape sequence. */
277       switch (str[*posn+1]) {
278       case '0': case '1': case '2': case '3':
279         if ((str[*posn+2] >= '0' && str[*posn+2] <= '7') &&
280             (str[*posn+3] >= '0' && str[*posn+3] <= '7')) {
281           c = (str[*posn+1] - '0') * 0100;
282           c += (str[*posn+2] - '0') * 010;
283           c += str[*posn+3] - '0';
284           if (c != 0) {
285             APPEND (c);
286             *posn += 4;
287           } else {
288             warning ("%s: \\0 cannot appear in string literal", hval->name);
289             return NULL;
290           }
291           break;
292         } else {
293           warning ("%s: invalid octal sequence in string literal", hval->name);
294           return NULL;
295         }
296       case 'x':
297         if (isxdigit (str[*posn+2]) && isxdigit (str[*posn+3])) {
298           c = hexval (str[*posn+2]) * 0x10;
299           c += hexval (str[*posn+3]);
300           if (c != 0) {
301             APPEND (c);
302             *posn += 4;
303           } else {
304             warning ("%s: \\0 cannot appear in string literal", hval->name);
305             return NULL;
306           }
307           break;
308         } else {
309           warning ("%s: invalid hex sequence in string literal", hval->name);
310           return NULL;
311         }
312       case 'a':
313         APPEND ('\a');
314         *posn += 2;
315         break;
316       case 'b':
317         APPEND ('\b');
318         *posn += 2;
319         break;
320       case 'f':
321         APPEND ('\f');
322         *posn += 2;
323         break;
324       case 'n':
325         APPEND ('\n');
326         *posn += 2;
327         break;
328       case 'r':
329         APPEND ('\r');
330         *posn += 2;
331         break;
332       case 't':
333         APPEND ('\t');
334         *posn += 2;
335         break;
336       case 'v':
337         APPEND ('\v');
338         *posn += 2;
339         break;
340       case '"':
341         APPEND ('"');
342         *posn += 2;
343         break;
344       case '\\':
345         APPEND ('\\');
346         *posn += 2;
347         break;
348
349       default:
350         warning ("%s: unterminated escape sequence in string literal",
351                  hval->name);
352         return NULL;
353       }
354     } else if (str[*posn] == '\0') { /* Unterminated string literal. */
355       warning ("%s: unterminated string literal in request", hval->name);
356       return NULL;
357     } else {                    /* Ordinary character. */
358       APPEND (str[*posn]);
359       (*posn)++;
360     }
361   }
362
363   /* Finish off the string and return it. */
364   APPEND ('\0');
365   (*posn)++;                    /* Skips over the final quote. */
366
367   return r->elts;
368 }
369
370 static int
371 parse_long (struct guest_description *hval,
372             const char *str, size_t *posn, long *ret)
373 {
374   error ("XXXXXXX parse_long not implemented XXXXXXX");
375   return -1;
376 }
377
378 /* For single line replies. */
379 void
380 send_reply (struct guest_description *hval, int code, const char *fs, ...)
381 {
382   va_list args;
383   char *msg;
384
385   /* All success codes must be 2xx. */
386   assert (code >= 200 && code < 300);
387
388   va_start (args, fs);
389   msg = apr_pvsprintf (hval->rpool, fs, args);
390   va_end (args);
391
392   /* The result string must not contain any CR or LF characters.  If
393    * not it's an internal error in the command, or else the caller has
394    * (somehow) managed to pass a bad string through.
395    */
396   if (contains_crlf (msg)) {
397     error ("%s: send_reply: refusing the send a reply message containing CR/LF characters.  This is a serious internal error in the current command.",
398            hval->name);
399     send_error (hval, 500);
400     return;
401   }
402
403   /* Send reply. */
404   hval->reply =
405     apr_psprintf (hval->rpool, PROTOCOL_VERSION " %03d %s" CRLF,
406                   code, msg);
407   hval->reply_size = strlen (hval->reply);
408   hval->reply_posn = 0;
409   hval->state = guest_state_reply;
410 }
411
412 void
413 send_error (struct guest_description *hval, int code)
414 {
415   const char *msg;
416
417   /* All errors must be 4xx or 5xx. */
418   assert (code >= 400 && code < 600);
419
420   /* NB: If you add a code, update COMMON STATUS CODES section
421    * in hostinfo-protocol.pod.
422    */
423   switch (code) {
424   case 400: msg = "Bad request"; break;
425   case 401: msg = "Command disabled"; break;
426   case 404: msg = "Command not found"; break;
427   case 406: msg = "Too frequent"; break;
428   case 500: msg = "Internal server error"; break;
429   default:  msg = "Unknown error"; break;
430   }
431
432   /* Construct the reply. */
433   hval->reply =
434     apr_psprintf (hval->rpool, PROTOCOL_VERSION " %03d %s" CRLF,
435                   code, msg);
436   hval->reply_size = strlen (hval->reply);
437   hval->reply_posn = 0;
438   hval->state = guest_state_reply;
439
440   /* Penalty is always increased on errors, to ensure the guest
441    * cannot flood us with invalid commands.
442    */
443   hval->penalty++;
444 }
445
446 static const char *
447 string_of_arg_type (enum arg_type t)
448 {
449   switch (t) {
450   case arg_type_string: return "string";
451   case arg_type_int:    return "int";
452   case arg_type_bool:   return "bool";
453   }
454   abort ();
455 }
456
457 int
458 get_args (apr_array_header_t *args, const char *argfs, ...)
459 {
460   va_list vargs;
461   int i = 0, ret = 0;
462
463   va_start (vargs, argfs);
464
465   while (*argfs) {
466     struct arg arg;
467
468     if (i >= args->nelts) {
469       ret = -1;
470       goto end;
471     }
472
473     arg = APR_ARRAY_IDX (args, i, struct arg);
474
475     switch (*argfs) {
476     case 's':
477       if (arg.type == arg_type_string)
478         * va_arg (vargs, const char **) = arg.u.str;
479       else {
480         ret = -1;
481         goto end;
482       }
483       break;
484
485     case 'i':
486       if (arg.type == arg_type_int)
487         * va_arg (vargs, int *) = arg.u.i;
488       else {
489         ret = -1;
490         goto end;
491       }
492       break;
493
494     case 'b':
495       if (arg.type == arg_type_bool)
496         * va_arg (vargs, int *) = arg.u.i;
497       else {
498         ret = -1;
499         goto end;
500       }
501       break;
502
503     default:
504       error ("get_args: invalid character '%c'", *argfs);
505       ret = -1;
506       goto end;
507     }
508
509     argfs++;
510     i++;
511   }
512
513   if (i < args->nelts)
514     ret = -1;
515
516  end:
517   va_end (vargs);
518
519   return ret;
520 }