Don't limit uploads.
[fedora-mingw.git] / openssl / openssl-thread-test.c
1 /* Test program to verify that RSA signing is thread-safe in OpenSSL. */
2
#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <openssl/crypto.h>
#include <openssl/err.h>
#include <openssl/objects.h>
#include <openssl/rand.h>
#include <openssl/rsa.h>
#include <openssl/md5.h>
#include <openssl/ssl.h>
19
/* Engine support is assumed to be wanted whenever the library is 0.9.6b
 * (0x0090602f) or newer.  Per the original note, that assumption only holds
 * for the OpenSSL versions bundled with RHL. */
#if defined(OPENSSL_VERSION_NUMBER) && (OPENSSL_VERSION_NUMBER >= 0x0090602fL)
#include <openssl/engine.h>
#define USE_ENGINE
#endif

/* Sizing knobs for the test run. */
enum {
        MAX_THREAD_COUNT = 10000,   /* cap applied to --threads; sizes main's thread table */
        ITERATION_COUNT = 10,       /* check/sign/verify rounds per worker thread */
        MAIN_COUNT = 100            /* check/sign/verify rounds in main before spawning */
};
30
31 /* OpenSSL requires us to provide thread ID and locking primitives. */
32 pthread_mutex_t *mutex_locks = NULL;
33 static unsigned long
34 thread_id_cb(void)
35 {
36         return (unsigned long) pthread_self();
37 }
38 static void
39 lock_cb(int mode, int n, const char *file, int line)
40 {
41         if (mode & CRYPTO_LOCK) {
42                 pthread_mutex_lock(&mutex_locks[n]);
43         } else {
44                 pthread_mutex_unlock(&mutex_locks[n]);
45         }
46 }
47
48 struct thread_args {
49         RSA *rsa;
50         int digest_type;
51         unsigned char *digest;
52         unsigned int digest_len;
53         unsigned char *signature;
54         unsigned int signature_len;
55         pthread_t main_thread;
56 };
57
/* Set by --print: trace progress to stderr. */
static int print = 0;

/* Set by --sign: serialize the key-check + signing step across threads. */
pthread_mutex_t sign_lock = PTHREAD_MUTEX_INITIALIZER;
static int locked_sign = 0;

/* Take sign_lock, but only when --sign was given. */
static void
SIGN_LOCK(void)
{
        if (locked_sign) {
                pthread_mutex_lock(&sign_lock);
        }
}

/* Release sign_lock, but only when --sign was given. */
static void
SIGN_UNLOCK(void)
{
        if (locked_sign) {
                pthread_mutex_unlock(&sign_lock);
        }
}

/* Set by --verify: serialize the verification step across threads. */
pthread_mutex_t verify_lock = PTHREAD_MUTEX_INITIALIZER;
static int locked_verify = 0;

/* Take verify_lock, but only when --verify was given. */
static void
VERIFY_LOCK(void)
{
        if (locked_verify) {
                pthread_mutex_lock(&verify_lock);
        }
}

/* Release verify_lock, but only when --verify was given. */
static void
VERIFY_UNLOCK(void)
{
        if (locked_verify) {
                pthread_mutex_unlock(&verify_lock);
        }
}
69
/* Number of failed checks across all threads.  It is only modified through
 * failure(), under failure_count_lock.  main() reads it after joining all
 * workers. */
pthread_mutex_t failure_count_lock = PTHREAD_MUTEX_INITIALIZER;
long failure_count = 0;

/* Record one failed check.  Callable from any thread. */
static void
failure(void)
{
        pthread_mutex_lock(&failure_count_lock);
        failure_count++;
        pthread_mutex_unlock(&failure_count_lock);
}
79
80 static void *
81 thread_main(void *argp)
82 {
83         struct thread_args *args = argp;
84         unsigned char *signature;
85         unsigned int signature_len, signature_alloc_len;
86         int ret, i;
87
88         signature_alloc_len = args->signature_len;
89         if (RSA_size(args->rsa) > signature_alloc_len) {
90                 signature_alloc_len = RSA_size(args->rsa);
91         }
92         signature = malloc(signature_alloc_len);
93         if (signature == NULL) {
94                 fprintf(stderr, "Skipping checks in thread %lu -- %s.\n",
95                         (unsigned long) pthread_self(), strerror(errno));
96                 pthread_exit(0);
97                 return NULL;
98         }
99         for (i = 0; i < ITERATION_COUNT; i++) {
100                 signature_len = signature_alloc_len;
101                 SIGN_LOCK();
102                 ret = RSA_check_key(args->rsa);
103                 ERR_print_errors_fp(stdout);
104                 if (ret != 1) {
105                         failure();
106                         break;
107                 }
108                 ret = RSA_sign(args->digest_type,
109                                args->digest,
110                                args->digest_len,
111                                signature, &signature_len,
112                                args->rsa);
113                 SIGN_UNLOCK();
114                 ERR_print_errors_fp(stdout);
115                 if (ret != 1) {
116                         failure();
117                         break;
118                 }
119
120                 VERIFY_LOCK();
121                 ret = RSA_verify(args->digest_type,
122                                  args->digest,
123                                  args->digest_len,
124                                  signature, signature_len,
125                                  args->rsa);
126                 VERIFY_UNLOCK();
127                 if (ret != 1) {
128                         fprintf(stderr,
129                                 "Signature from thread %lu(%d) fails "
130                                 "verification (passed in thread #%lu)!\n",
131                                 (long) pthread_self(), i,
132                                 (long) args->main_thread);
133                         ERR_print_errors_fp(stdout);
134                         failure();
135                         continue;
136                 }
137                 if (print) {
138                         fprintf(stderr, ">%d\n", i);
139                 }
140         }
141         free(signature);
142
143         pthread_exit(0);
144
145         return NULL;
146 }
147
/*
 * Return a newly malloc'd copy of the LEN bytes at S; the caller frees it.
 *
 * On allocation failure a message is printed and the process aborts.  abort()
 * is used instead of assert() so the check survives builds with NDEBUG.
 * LEN may be 0; at least one byte is always allocated so malloc(0) returning
 * NULL is not mistaken for running out of memory.
 */
unsigned char *
xmemdup(const unsigned char *s, size_t len)
{
        unsigned char *r;

        r = malloc(len > 0 ? len : 1);
        if (r == NULL) {
                fprintf(stderr, "Out of memory.\n");
                ERR_print_errors_fp(stdout);
                abort();
        }
        if (len > 0) {
                memcpy(r, s, len);
        }
        return r;
}
161
162 int
163 main(int argc, char **argv)
164 {
165         RSA *rsa;
166         MD5_CTX md5;
167         int fd, i;
168         pthread_t threads[MAX_THREAD_COUNT];
169         int thread_count = 1000;
170         unsigned char *message, *digest;
171         unsigned int message_len, digest_len;
172         unsigned char *correct_signature;
173         unsigned int correct_siglen, ret;
174         struct thread_args master_args, *args;
175         int sync = 0, seed = 0;
176         int again = 1;
177 #ifdef USE_ENGINE
178         char *engine = NULL;
179         ENGINE *e = NULL;
180 #endif
181
182         pthread_mutex_init(&failure_count_lock, NULL);
183
184         for (i = 1; i < argc; i++) {
185                 if (strcmp(argv[i], "--seed") == 0) {
186                         printf("Seeding PRNG.\n");
187                         seed++;
188                 } else
189                 if (strcmp(argv[i], "--sync") == 0) {
190                         printf("Running synchronized.\n");
191                         sync++;
192                 } else
193                 if ((strcmp(argv[i], "--threads") == 0) && (i < argc - 1)) {
194                         i++;
195                         thread_count = atol(argv[i]);
196                         if (thread_count > MAX_THREAD_COUNT) {
197                                 thread_count = MAX_THREAD_COUNT;
198                         }
199                         printf("Starting %d threads.\n", thread_count);
200                         sync++;
201                 } else
202                 if (strcmp(argv[i], "--sign") == 0) {
203                         printf("Locking signing.\n");
204                         locked_sign++;
205                 } else
206                 if (strcmp(argv[i], "--verify") == 0) {
207                         printf("Locking verifies.\n");
208                         locked_verify++;
209                 } else
210                 if (strcmp(argv[i], "--print") == 0) {
211                         printf("Tracing.\n");
212                         print++;
213 #ifdef USE_ENGINE
214                 } else
215                 if ((strcmp(argv[i], "--engine") == 0) && (i < argc - 1)) {
216                         printf("Using engine \"%s\".\n", argv[i + 1]);
217                         engine = argv[i + 1];
218                         i++;
219 #endif
220                 } else {
221                         printf("Bad argument: %s\n", argv[i]);
222                         return 1;
223                 }
224         }
225
226         /* Get some random data to sign. */
227         fd = open("/dev/urandom", O_RDONLY);
228         if (fd == -1) {
229                 fprintf(stderr, "Error opening /dev/urandom: %s\n",
230                         strerror(errno));
231         }
232
233         if (print) {
234                 fprintf(stderr, "Reading random data.\n");
235         }
236         message = malloc(message_len = 9371);
237         read(fd, message, message_len);
238         close(fd);
239
240         /* Initialize the SSL library and set up thread-safe locking. */
241         ERR_load_crypto_strings();
242         SSL_library_init();
243         mutex_locks = malloc(sizeof(pthread_mutex_t) * CRYPTO_num_locks());
244         for (i = 0; i < CRYPTO_num_locks(); i++) {
245                 pthread_mutex_init(&mutex_locks[i], NULL);
246         }
247         CRYPTO_set_id_callback(thread_id_cb);
248         CRYPTO_set_locking_callback(lock_cb);
249         ERR_print_errors_fp(stdout);
250
251         /* Seed the PRNG if we were asked to do so. */
252         if (seed) {
253                 if (print) {
254                         fprintf(stderr, "Seeding PRNG.\n");
255                 }
256                 RAND_add(message, message_len, message_len);
257                 ERR_print_errors_fp(stdout);
258         }
259
260         /* Turn on a hardware crypto device if asked to do so. */
261 #ifdef USE_ENGINE
262         if (engine) {
263 #if OPENSSL_VERSION_NUMBER  >= 0x0090700fL
264                 ENGINE_load_builtin_engines();
265 #endif
266                 if (print) {
267                         fprintf(stderr, "Initializing \"%s\" engine.\n",
268                                 engine);
269                 }
270                 e = ENGINE_by_id(engine);
271                 ERR_print_errors_fp(stdout);
272                 if (e) {
273                         i = ENGINE_init(e);
274                         ERR_print_errors_fp(stdout);
275                         i = ENGINE_set_default_RSA(e);
276                         ERR_print_errors_fp(stdout);
277                 }
278         }
279 #endif
280
281         /* Compute the digest for the signature. */
282         if (print) {
283                 fprintf(stderr, "Computing digest.\n");
284         }
285         digest = malloc(digest_len = MD5_DIGEST_LENGTH);
286         MD5_Init(&md5);
287         MD5_Update(&md5, message, message_len);
288         MD5_Final(digest, &md5);
289
290         /* Generate a signing key. */
291         if (print) {
292                 fprintf(stderr, "Generating key.\n");
293         }
294         rsa = RSA_generate_key(4096, 3, NULL, NULL);
295         ERR_print_errors_fp(stdout);
296         if (rsa == NULL) {
297                 _exit(1);
298         }
299
300         /* Sign the data. */
301         correct_siglen = RSA_size(rsa);
302         correct_signature = malloc(correct_siglen);
303         for (i = 0; i < MAIN_COUNT; i++) {
304                 if (print) {
305                         fprintf(stderr, "Signing data (%d).\n", i);
306                 }
307                 ret = RSA_check_key(rsa);
308                 ERR_print_errors_fp(stdout);
309                 if (ret != 1) {
310                         failure();
311                 }
312                 correct_siglen = RSA_size(rsa);
313                 ret = RSA_sign(NID_md5, digest, digest_len,
314                                correct_signature, &correct_siglen,
315                                rsa);
316                 ERR_print_errors_fp(stdout);
317                 if (ret != 1) {
318                         _exit(2);
319                 }
320                 if (print) {
321                         fprintf(stderr, "Verifying data (%d).\n", i);
322                 }
323                 ret = RSA_verify(NID_md5, digest, digest_len,
324                                  correct_signature, correct_siglen,
325                                  rsa);
326                 if (ret != 1) {
327                         _exit(2);
328                 }
329         }
330
331         /* Collect up the inforamtion which other threads will need for
332          * comparing their signature results with ours. */
333         master_args.rsa = rsa;
334         master_args.digest_type = NID_md5;
335         master_args.digest = digest;
336         master_args.digest_len = digest_len;
337         master_args.signature = correct_signature;
338         master_args.signature_len = correct_siglen;
339         master_args.main_thread = pthread_self();
340         
341         fprintf(stdout, "Performing %d signatures in each of %d threads "
342                 "(%d, %d).\n", ITERATION_COUNT, thread_count,
343                 digest_len, correct_siglen);
344         fflush(NULL);
345
346         /* Start up all of the threads. */
347         for (i = 0; i < thread_count; i++) {
348                 args = malloc(sizeof(struct thread_args));
349                 args->rsa = RSAPrivateKey_dup(master_args.rsa);
350                 args->digest_type = master_args.digest_type;
351                 args->digest_len = master_args.digest_len;
352                 args->digest = xmemdup(master_args.digest, args->digest_len);
353                 args->signature_len = master_args.signature_len;
354                 args->signature = xmemdup(master_args.signature,
355                                           args->signature_len);
356                 args->main_thread = pthread_self();
357                 ret = pthread_create(&threads[i], NULL, thread_main, args);
358                 while ((ret != 0) && (errno == EAGAIN)) {
359                         ret = pthread_create(&threads[i], NULL,
360                                              thread_main, &args);
361                         fprintf(stderr, "Thread limit hit at %d.\n", i);
362                 }
363                 if (ret != 0) {
364                         fprintf(stderr, "Unable to create thread %d: %s.\n",
365                                 i, strerror(errno));
366                         threads[i] = -1;
367                 } else {
368                         if (sync) {
369                                 ret = pthread_join(threads[i], NULL);
370                                 assert(ret == 0);
371                         }
372                         if (print) {
373                                 fprintf(stderr, "%d\n", i);
374                         }
375                 }
376         }
377
378         /* Wait for all threads to complete.  So long as we can find an
379          * unjoined thread, keep joining threads. */
380         do {
381                 again = 0;
382                 for (i = 0; i < thread_count; i++) {
383                         /* If we have an unterminated thread, join it. */
384                         if (threads[i] != -1) {
385                                 again = 1;
386                                 if (print) {
387                                         fprintf(stderr, "Joining thread %d.\n",
388                                                 i);
389                                 }
390                                 pthread_join(threads[i], NULL);
391                                 threads[i] = -1;
392                                 break;
393                         }
394                 }
395         } while (again == 1);
396
397         fprintf(stderr, "%ld failures\n", failure_count);
398
399         return (failure_count != 0);
400 }