cvsdist 321fa67
/* Test program to verify that RSA signing is thread-safe in OpenSSL. */
cvsdist 321fa67
cvsdist 321fa67
#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <openssl/crypto.h>
#include <openssl/err.h>
#include <openssl/md5.h>
#include <openssl/objects.h>
#include <openssl/rand.h>
#include <openssl/rsa.h>
#include <openssl/ssl.h>
cvsdist 321fa67
cvsdist 321fa67
/* Just assume we want to do engine stuff if we're using 0.9.6b or
cvsdist 321fa67
 * higher. This assumption is only valid for versions bundled with RHL. */
cvsdist 321fa67
#if OPENSSL_VERSION_NUMBER  >= 0x0090602fL
cvsdist 321fa67
#include <openssl/engine.h>
cvsdist 321fa67
#define USE_ENGINE
cvsdist 321fa67
#endif
cvsdist 321fa67
cvsdist 321fa67
/* Tunables: the upper bound accepted for --threads, the number of
 * check/sign/verify rounds each worker thread performs, and the number of
 * sign/verify rounds main() performs before starting any threads. */
enum {
	MAX_THREAD_COUNT = 10000,
	ITERATION_COUNT = 10,
	MAIN_COUNT = 100
};
cvsdist 321fa67
cvsdist 321fa67
/* OpenSSL requires us to provide thread ID and locking primitives. */
cvsdist 321fa67
pthread_mutex_t *mutex_locks = NULL;

/* OpenSSL thread-ID callback: identify the calling thread by its pthread
 * handle converted to an unsigned long. */
static unsigned long
thread_id_cb(void)
{
	pthread_t self = pthread_self();

	return (unsigned long) self;
}
cvsdist 321fa67
static void
cvsdist 321fa67
lock_cb(int mode, int n, const char *file, int line)
cvsdist 321fa67
{
cvsdist 321fa67
	if (mode & CRYPTO_LOCK) {
cvsdist 321fa67
		pthread_mutex_lock(&mutex_locks[n]);
cvsdist 321fa67
	} else {
cvsdist 321fa67
		pthread_mutex_unlock(&mutex_locks[n]);
cvsdist 321fa67
	}
cvsdist 321fa67
}
cvsdist 321fa67
cvsdist 321fa67
struct thread_args {
cvsdist 321fa67
	RSA *rsa;
cvsdist 321fa67
	int digest_type;
cvsdist 321fa67
	unsigned char *digest;
cvsdist 321fa67
	unsigned int digest_len;
cvsdist 321fa67
	unsigned char *signature;
cvsdist 321fa67
	unsigned int signature_len;
cvsdist 321fa67
	pthread_t main_thread;
cvsdist 321fa67
};
cvsdist 321fa67
cvsdist 321fa67
/* Non-zero when --print was given: emit progress tracing on stderr.
 * Objects with static storage duration start out as zero. */
static int print;
cvsdist 321fa67
cvsdist 321fa67
/* Optional serialization of key checking and signing (enabled by --sign).
 * While locked_sign is zero both helpers do nothing. */
pthread_mutex_t sign_lock = PTHREAD_MUTEX_INITIALIZER;
static int locked_sign = 0;

static void
SIGN_LOCK()
{
	if (!locked_sign) {
		return;
	}
	pthread_mutex_lock(&sign_lock);
}

static void
SIGN_UNLOCK()
{
	if (!locked_sign) {
		return;
	}
	pthread_mutex_unlock(&sign_lock);
}
cvsdist 321fa67
cvsdist 321fa67
/* Optional serialization of signature verification (enabled by --verify).
 * While locked_verify is zero both helpers do nothing. */
pthread_mutex_t verify_lock = PTHREAD_MUTEX_INITIALIZER;
static int locked_verify = 0;

static void
VERIFY_LOCK()
{
	if (!locked_verify) {
		return;
	}
	pthread_mutex_lock(&verify_lock);
}

static void
VERIFY_UNLOCK()
{
	if (!locked_verify) {
		return;
	}
	pthread_mutex_unlock(&verify_lock);
}
cvsdist 321fa67
cvsdist 321fa67
pthread_mutex_t failure_count_lock = PTHREAD_MUTEX_INITIALIZER;
long failure_count = 0;

/* Record one failed check.  The shared counter is only modified while
 * failure_count_lock is held, so any thread may call this. */
static void
failure()
{
	pthread_mutex_t *guard = &failure_count_lock;

	pthread_mutex_lock(guard);
	failure_count += 1;
	pthread_mutex_unlock(guard);
}
cvsdist 321fa67
cvsdist 321fa67
static void *
cvsdist 321fa67
thread_main(void *argp)
cvsdist 321fa67
{
cvsdist 321fa67
	struct thread_args *args = argp;
cvsdist 321fa67
	unsigned char *signature;
cvsdist 321fa67
	unsigned int signature_len, signature_alloc_len;
cvsdist 321fa67
	int ret, i;
cvsdist 321fa67
cvsdist 321fa67
	signature_alloc_len = args->signature_len;
cvsdist 321fa67
	if (RSA_size(args->rsa) > signature_alloc_len) {
cvsdist 321fa67
		signature_alloc_len = RSA_size(args->rsa);
cvsdist 321fa67
	}
cvsdist 321fa67
	signature = malloc(signature_alloc_len);
cvsdist 321fa67
	if (signature == NULL) {
cvsdist 321fa67
		fprintf(stderr, "Skipping checks in thread %lu -- %s.\n",
cvsdist 321fa67
			(unsigned long) pthread_self(), strerror(errno));
cvsdist 321fa67
		pthread_exit(0);
cvsdist 321fa67
		return NULL;
cvsdist 321fa67
	}
cvsdist 321fa67
	for (i = 0; i < ITERATION_COUNT; i++) {
cvsdist 321fa67
		signature_len = signature_alloc_len;
cvsdist 321fa67
		SIGN_LOCK();
cvsdist 321fa67
		ret = RSA_check_key(args->rsa);
cvsdist 321fa67
		ERR_print_errors_fp(stdout);
cvsdist 321fa67
		if (ret != 1) {
cvsdist 321fa67
			failure();
cvsdist 321fa67
			break;
cvsdist 321fa67
		}
cvsdist 321fa67
		ret = RSA_sign(args->digest_type,
cvsdist 321fa67
			       args->digest,
cvsdist 321fa67
			       args->digest_len,
cvsdist 321fa67
			       signature, &signature_len,
cvsdist 321fa67
			       args->rsa);
cvsdist 321fa67
		SIGN_UNLOCK();
cvsdist 321fa67
		ERR_print_errors_fp(stdout);
cvsdist 321fa67
		if (ret != 1) {
cvsdist 321fa67
			failure();
cvsdist 321fa67
			break;
cvsdist 321fa67
		}
cvsdist 321fa67
cvsdist 321fa67
		VERIFY_LOCK();
cvsdist 321fa67
		ret = RSA_verify(args->digest_type,
cvsdist 321fa67
			         args->digest,
cvsdist 321fa67
			         args->digest_len,
cvsdist 321fa67
			         signature, signature_len,
cvsdist 321fa67
			         args->rsa);
cvsdist 321fa67
		VERIFY_UNLOCK();
cvsdist 321fa67
		if (ret != 1) {
cvsdist 321fa67
			fprintf(stderr,
cvsdist 321fa67
				"Signature from thread %lu(%d) fails "
cvsdist 321fa67
				"verification (passed in thread #%lu)!\n",
cvsdist 321fa67
				(long) pthread_self(), i,
cvsdist 321fa67
				(long) args->main_thread);
cvsdist 321fa67
			ERR_print_errors_fp(stdout);
cvsdist 321fa67
			failure();
cvsdist 321fa67
			continue;
cvsdist 321fa67
		}
cvsdist 321fa67
		if (print) {
cvsdist 321fa67
			fprintf(stderr, ">%d\n", i);
cvsdist 321fa67
		}
cvsdist 321fa67
	}
cvsdist 321fa67
	free(signature);
cvsdist 321fa67
cvsdist 321fa67
	pthread_exit(0);
cvsdist 321fa67
cvsdist 321fa67
	return NULL;
cvsdist 321fa67
}
cvsdist 321fa67
cvsdist 321fa67
/* Return a newly malloc'd copy of the LEN bytes at S; the caller owns the
 * result and must free() it.  Never returns NULL: if memory runs out a
 * message is printed and the process aborts.  An explicit check is used
 * instead of assert(), which disappears under -DNDEBUG. */
unsigned char *
xmemdup(const unsigned char *s, size_t len)
{
	/* malloc(0) may legitimately return NULL; request one byte so that
	 * an empty copy is not mistaken for an allocation failure. */
	unsigned char *r = malloc(len > 0 ? len : 1);

	if (r == NULL) {
		fprintf(stderr, "Out of memory.\n");
		abort();
	}
	if (len > 0) {
		memcpy(r, s, len);
	}
	return r;
}
cvsdist 321fa67
cvsdist 321fa67
int
cvsdist 321fa67
main(int argc, char **argv)
cvsdist 321fa67
{
cvsdist 321fa67
	RSA *rsa;
cvsdist 321fa67
	MD5_CTX md5;
cvsdist 321fa67
	int fd, i;
cvsdist 321fa67
	pthread_t threads[MAX_THREAD_COUNT];
cvsdist 321fa67
	int thread_count = 1000;
cvsdist 321fa67
	unsigned char *message, *digest;
cvsdist 321fa67
	unsigned int message_len, digest_len;
cvsdist 321fa67
	unsigned char *correct_signature;
cvsdist 321fa67
	unsigned int correct_siglen, ret;
cvsdist 321fa67
	struct thread_args master_args, *args;
cvsdist 321fa67
	int sync = 0, seed = 0;
cvsdist 321fa67
	int again = 1;
cvsdist 321fa67
#ifdef USE_ENGINE
cvsdist 321fa67
	char *engine = NULL;
cvsdist 321fa67
	ENGINE *e = NULL;
cvsdist 321fa67
#endif
cvsdist 321fa67
cvsdist 321fa67
	pthread_mutex_init(&failure_count_lock, NULL);
cvsdist 321fa67
cvsdist 321fa67
	for (i = 1; i < argc; i++) {
cvsdist 321fa67
		if (strcmp(argv[i], "--seed") == 0) {
cvsdist 321fa67
			printf("Seeding PRNG.\n");
cvsdist 321fa67
			seed++;
cvsdist 321fa67
		} else
cvsdist 321fa67
		if (strcmp(argv[i], "--sync") == 0) {
cvsdist 321fa67
			printf("Running synchronized.\n");
cvsdist 321fa67
			sync++;
cvsdist 321fa67
		} else
cvsdist 321fa67
		if ((strcmp(argv[i], "--threads") == 0) && (i < argc - 1)) {
cvsdist 321fa67
			i++;
cvsdist 321fa67
			thread_count = atol(argv[i]);
cvsdist 321fa67
			if (thread_count > MAX_THREAD_COUNT) {
cvsdist 321fa67
				thread_count = MAX_THREAD_COUNT;
cvsdist 321fa67
			}
cvsdist 321fa67
			printf("Starting %d threads.\n", thread_count);
cvsdist 321fa67
			sync++;
cvsdist 321fa67
		} else
cvsdist 321fa67
		if (strcmp(argv[i], "--sign") == 0) {
cvsdist 321fa67
			printf("Locking signing.\n");
cvsdist 321fa67
			locked_sign++;
cvsdist 321fa67
		} else
cvsdist 321fa67
		if (strcmp(argv[i], "--verify") == 0) {
cvsdist 321fa67
			printf("Locking verifies.\n");
cvsdist 321fa67
			locked_verify++;
cvsdist 321fa67
		} else
cvsdist 321fa67
		if (strcmp(argv[i], "--print") == 0) {
cvsdist 321fa67
			printf("Tracing.\n");
cvsdist 321fa67
			print++;
cvsdist 321fa67
#ifdef USE_ENGINE
cvsdist 321fa67
		} else
cvsdist 321fa67
		if ((strcmp(argv[i], "--engine") == 0) && (i < argc - 1)) {
cvsdist 321fa67
			printf("Using engine \"%s\".\n", argv[i + 1]);
cvsdist 321fa67
			engine = argv[i + 1];
cvsdist 321fa67
			i++;
cvsdist 321fa67
#endif
cvsdist 321fa67
		} else {
cvsdist 321fa67
			printf("Bad argument: %s\n", argv[i]);
cvsdist 321fa67
			return 1;
cvsdist 321fa67
		}
cvsdist 321fa67
	}
cvsdist 321fa67
cvsdist 321fa67
	/* Get some random data to sign. */
cvsdist 321fa67
	fd = open("/dev/urandom", O_RDONLY);
cvsdist 321fa67
	if (fd == -1) {
cvsdist 321fa67
		fprintf(stderr, "Error opening /dev/urandom: %s\n",
cvsdist 321fa67
			strerror(errno));
cvsdist 321fa67
	}
cvsdist 321fa67
cvsdist 321fa67
	if (print) {
cvsdist 321fa67
		fprintf(stderr, "Reading random data.\n");
cvsdist 321fa67
	}
cvsdist 321fa67
	message = malloc(message_len = 9371);
cvsdist 321fa67
	read(fd, message, message_len);
cvsdist 321fa67
	close(fd);
cvsdist 321fa67
cvsdist 321fa67
	/* Initialize the SSL library and set up thread-safe locking. */
cvsdist 321fa67
	ERR_load_crypto_strings();
cvsdist 321fa67
	SSL_library_init();
cvsdist 321fa67
	mutex_locks = malloc(sizeof(pthread_mutex_t) * CRYPTO_num_locks());
cvsdist 321fa67
	for (i = 0; i < CRYPTO_num_locks(); i++) {
cvsdist 321fa67
		pthread_mutex_init(&mutex_locks[i], NULL);
cvsdist 321fa67
	}
cvsdist 321fa67
	CRYPTO_set_id_callback(thread_id_cb);
cvsdist 321fa67
	CRYPTO_set_locking_callback(lock_cb);
cvsdist 321fa67
	ERR_print_errors_fp(stdout);
cvsdist 321fa67
cvsdist 321fa67
	/* Seed the PRNG if we were asked to do so. */
cvsdist 321fa67
	if (seed) {
cvsdist 321fa67
		if (print) {
cvsdist 321fa67
			fprintf(stderr, "Seeding PRNG.\n");
cvsdist 321fa67
		}
cvsdist 321fa67
		RAND_add(message, message_len, message_len);
cvsdist 321fa67
		ERR_print_errors_fp(stdout);
cvsdist 321fa67
	}
cvsdist 321fa67
cvsdist 321fa67
	/* Turn on a hardware crypto device if asked to do so. */
cvsdist 321fa67
#ifdef USE_ENGINE
cvsdist 321fa67
	if (engine) {
cvsdist 321fa67
#if OPENSSL_VERSION_NUMBER  >= 0x0090700fL
cvsdist 321fa67
		ENGINE_load_builtin_engines();
cvsdist 321fa67
#endif
cvsdist 321fa67
		if (print) {
cvsdist 321fa67
			fprintf(stderr, "Initializing \"%s\" engine.\n",
cvsdist 321fa67
				engine);
cvsdist 321fa67
		}
cvsdist 321fa67
		e = ENGINE_by_id(engine);
cvsdist 321fa67
		ERR_print_errors_fp(stdout);
cvsdist 321fa67
		if (e) {
cvsdist 321fa67
			i = ENGINE_init(e);
cvsdist 321fa67
			ERR_print_errors_fp(stdout);
cvsdist 321fa67
			i = ENGINE_set_default_RSA(e);
cvsdist 321fa67
			ERR_print_errors_fp(stdout);
cvsdist 321fa67
		}
cvsdist 321fa67
	}
cvsdist 321fa67
#endif
cvsdist 321fa67
cvsdist 321fa67
	/* Compute the digest for the signature. */
cvsdist 321fa67
	if (print) {
cvsdist 321fa67
		fprintf(stderr, "Computing digest.\n");
cvsdist 321fa67
	}
cvsdist 321fa67
	digest = malloc(digest_len = MD5_DIGEST_LENGTH);
cvsdist 321fa67
	MD5_Init(&md5;;
cvsdist 321fa67
	MD5_Update(&md5, message, message_len);
cvsdist 321fa67
	MD5_Final(digest, &md5;;
cvsdist 321fa67
cvsdist 321fa67
	/* Generate a signing key. */
cvsdist 321fa67
	if (print) {
cvsdist 321fa67
		fprintf(stderr, "Generating key.\n");
cvsdist 321fa67
	}
cvsdist 321fa67
	rsa = RSA_generate_key(4096, 3, NULL, NULL);
cvsdist 321fa67
	ERR_print_errors_fp(stdout);
cvsdist 321fa67
	if (rsa == NULL) {
cvsdist 321fa67
		_exit(1);
cvsdist 321fa67
	}
cvsdist 321fa67
cvsdist 321fa67
	/* Sign the data. */
cvsdist 321fa67
	correct_siglen = RSA_size(rsa);
cvsdist 321fa67
	correct_signature = malloc(correct_siglen);
cvsdist 321fa67
	for (i = 0; i < MAIN_COUNT; i++) {
cvsdist 321fa67
		if (print) {
cvsdist 321fa67
			fprintf(stderr, "Signing data (%d).\n", i);
cvsdist 321fa67
		}
cvsdist 321fa67
		ret = RSA_check_key(rsa);
cvsdist 321fa67
		ERR_print_errors_fp(stdout);
cvsdist 321fa67
		if (ret != 1) {
cvsdist 321fa67
			failure();
cvsdist 321fa67
		}
cvsdist 321fa67
		correct_siglen = RSA_size(rsa);
cvsdist 321fa67
		ret = RSA_sign(NID_md5, digest, digest_len,
cvsdist 321fa67
			       correct_signature, &correct_siglen,
cvsdist 321fa67
			       rsa);
cvsdist 321fa67
		ERR_print_errors_fp(stdout);
cvsdist 321fa67
		if (ret != 1) {
cvsdist 321fa67
			_exit(2);
cvsdist 321fa67
		}
cvsdist 321fa67
		if (print) {
cvsdist 321fa67
			fprintf(stderr, "Verifying data (%d).\n", i);
cvsdist 321fa67
		}
cvsdist 321fa67
		ret = RSA_verify(NID_md5, digest, digest_len,
cvsdist 321fa67
			         correct_signature, correct_siglen,
cvsdist 321fa67
			         rsa);
cvsdist 321fa67
		if (ret != 1) {
cvsdist 321fa67
			_exit(2);
cvsdist 321fa67
		}
cvsdist 321fa67
	}
cvsdist 321fa67
cvsdist 321fa67
	/* Collect up the inforamtion which other threads will need for
cvsdist 321fa67
	 * comparing their signature results with ours. */
cvsdist 321fa67
	master_args.rsa = rsa;
cvsdist 321fa67
	master_args.digest_type = NID_md5;
cvsdist 321fa67
	master_args.digest = digest;
cvsdist 321fa67
	master_args.digest_len = digest_len;
cvsdist 321fa67
	master_args.signature = correct_signature;
cvsdist 321fa67
	master_args.signature_len = correct_siglen;
cvsdist 321fa67
	master_args.main_thread = pthread_self();
cvsdist 321fa67
	
cvsdist 321fa67
	fprintf(stdout, "Performing %d signatures in each of %d threads "
cvsdist 321fa67
		"(%d, %d).\n", ITERATION_COUNT, thread_count,
cvsdist 321fa67
		digest_len, correct_siglen);
cvsdist 321fa67
	fflush(NULL);
cvsdist 321fa67
cvsdist 321fa67
	/* Start up all of the threads. */
cvsdist 321fa67
	for (i = 0; i < thread_count; i++) {
cvsdist 321fa67
		args = malloc(sizeof(struct thread_args));
cvsdist 321fa67
		args->rsa = RSAPrivateKey_dup(master_args.rsa);
cvsdist 321fa67
		args->digest_type = master_args.digest_type;
cvsdist 321fa67
		args->digest_len = master_args.digest_len;
cvsdist 321fa67
		args->digest = xmemdup(master_args.digest, args->digest_len);
cvsdist 321fa67
		args->signature_len = master_args.signature_len;
cvsdist 321fa67
		args->signature = xmemdup(master_args.signature,
cvsdist 321fa67
					  args->signature_len);
cvsdist 321fa67
		args->main_thread = pthread_self();
cvsdist 321fa67
		ret = pthread_create(&threads[i], NULL, thread_main, args);
cvsdist 321fa67
		while ((ret != 0) && (errno == EAGAIN)) {
cvsdist 321fa67
			ret = pthread_create(&threads[i], NULL,
cvsdist 321fa67
					     thread_main, &args);
cvsdist 321fa67
			fprintf(stderr, "Thread limit hit at %d.\n", i);
cvsdist 321fa67
		}
cvsdist 321fa67
		if (ret != 0) {
cvsdist 321fa67
			fprintf(stderr, "Unable to create thread %d: %s.\n",
cvsdist 321fa67
				i, strerror(errno));
cvsdist 321fa67
			threads[i] = -1;
cvsdist 321fa67
		} else {
cvsdist 321fa67
			if (sync) {
cvsdist 321fa67
				ret = pthread_join(threads[i], NULL);
cvsdist 321fa67
				assert(ret == 0);
cvsdist 321fa67
			}
cvsdist 321fa67
			if (print) {
cvsdist 321fa67
				fprintf(stderr, "%d\n", i);
cvsdist 321fa67
			}
cvsdist 321fa67
		}
cvsdist 321fa67
	}
cvsdist 321fa67
cvsdist 321fa67
	/* Wait for all threads to complete.  So long as we can find an
cvsdist 321fa67
	 * unjoined thread, keep joining threads. */
cvsdist 321fa67
	do {
cvsdist 321fa67
		again = 0;
cvsdist 321fa67
		for (i = 0; i < thread_count; i++) {
cvsdist 321fa67
			/* If we have an unterminated thread, join it. */
cvsdist 321fa67
			if (threads[i] != -1) {
cvsdist 321fa67
				again = 1;
cvsdist 321fa67
				if (print) {
cvsdist 321fa67
					fprintf(stderr, "Joining thread %d.\n",
cvsdist 321fa67
						i);
cvsdist 321fa67
				}
cvsdist 321fa67
				pthread_join(threads[i], NULL);
cvsdist 321fa67
				threads[i] = -1;
cvsdist 321fa67
				break;
cvsdist 321fa67
			}
cvsdist 321fa67
		}
cvsdist 321fa67
	} while (again == 1);
cvsdist 321fa67
cvsdist 321fa67
	fprintf(stderr, "%ld failures\n", failure_count);
cvsdist 321fa67
cvsdist 321fa67
	return (failure_count != 0);
cvsdist 321fa67
}