Blob Blame History Raw
commit 388ddb783c7bff39feea87dd7e2ac0bfd410d82a
Author: John Dennis <jdennis@redhat.com>
Date:   Thu Mar 16 10:18:57 2017 -0400

    Use jwcrypto instead of jwt
    
    * Replace all imports of jwt with jwcrypto, continue to rely on cryptography
    * Update doc replacing jwt with jwcrypto
    * Update dependencies and requirements
    * Add utility functions get_rsa_public_key() and get_rsa_private_key()
      to obtain cryptography's RSAPublicKey and RSAPrivateKey
    * Add utility normalize_claims() replicate type conversion jwt offered,
      primarially datatime conversion
    * Refactor to use common code and separate concerns
    * Replace jwt's RSA-SHA1 signature with primitives from cryptography
    
    Signed-off-by: John Dennis <jdennis@redhat.com>

diff --git a/oauthlib/common.py b/oauthlib/common.py
index 5d999b2..5fc0a3c 100644
--- a/oauthlib/common.py
+++ b/oauthlib/common.py
@@ -15,6 +15,16 @@ import random
 import re
 import sys
 import time
+from calendar import timegm
+
+from jwcrypto.jwk import JWK
+from jwcrypto.jwt import JWT
+from jwcrypto.common import json_decode
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives.serialization import (
+    load_pem_private_key, load_pem_public_key)
+from cryptography.hazmat.primitives.asymmetric.rsa import (
+    RSAPrivateKey, RSAPublicKey)
 
 try:
     from urllib import quote as _quote
@@ -229,28 +239,69 @@ def generate_token(length=30, chars=UNICODE_ASCII_CHARACTER_SET):
     return ''.join(rand.choice(chars) for x in range(length))
 
 
-def generate_signed_token(private_pem, request):
-    import jwt
+def get_rsa_private_key(data):
+    if isinstance(data, (bytes_type, unicode_type)):
+        if isinstance(data, unicode_type):
+            data = data.encode('ascii')
+        # load_pem_private_key() requires the data to be str or bytes
+        private_key = load_pem_private_key(data, None, default_backend())
+    else:
+        private_key = data
+
+    if not isinstance(private_key, RSAPrivateKey):
+        raise TypeError("Expected RSAPrivateKey, but got %s" %
+                        private_key.__class__.__name__)
+
+    return private_key
+
+
+def get_rsa_public_key(data):
+    if isinstance(data, (bytes_type, unicode_type)):
+        if isinstance(data, unicode_type):
+            data = data.encode('ascii')
+        # load_pem_public_key() requires the data to be str or bytes
+        public_key = load_pem_public_key(data, default_backend())
+    else:
+        public_key = data
+
+    if not isinstance(public_key, RSAPublicKey):
+        raise TypeError("Expected RSAPublicKey, but got %s" %
+                        public_key.__class__.__name__)
+
+    return public_key
 
-    now = datetime.datetime.utcnow()
 
+def normalize_claims(claims):
+    for claim in ['exp', 'iat', 'nbf']:
+        # Convert datetime to a intDate value in known time-format claims
+        if isinstance(claims.get(claim), datetime.datetime):
+            claims[claim] = timegm(claims[claim].utctimetuple())
+    return claims
+
+
+def generate_jwt_assertion(private_key, claims):
+    rsa_private_key = get_rsa_private_key(private_key)
+    jwkey = JWK.from_pyca(rsa_private_key)
+    token = JWT(header={'alg': 'RS256'}, claims=normalize_claims(claims))
+    token.make_signed_token(jwkey)
+    return to_unicode(token.serialize(), "UTF-8")
+
+
+def generate_signed_token(private_pem, request):
+    now = datetime.datetime.utcnow()
     claims = {
         'scope': request.scope,
         'exp': now + datetime.timedelta(seconds=request.expires_in)
     }
-
     claims.update(request.claims)
+    return generate_jwt_assertion(private_pem, claims)
 
-    token = jwt.encode(claims, private_pem, 'RS256')
-    token = to_unicode(token, "UTF-8")
-
-    return token
-
-
-def verify_signed_token(public_pem, token):
-    import jwt
 
-    return jwt.decode(token, public_pem, algorithms=['RS256'])
+def verify_signed_token(public_key, token):
+    rsa_public_key = get_rsa_public_key(public_key)
+    jwkey = JWK.from_pyca(rsa_public_key)
+    signed_token = JWT(key=jwkey, jwt=token)
+    return json_decode(signed_token.claims)
 
 
 def generate_client_id(length=30, chars=CLIENT_ID_CHARACTER_SET):
diff --git a/oauthlib/oauth1/rfc5849/signature.py b/oauthlib/oauth1/rfc5849/signature.py
index 8fa22ba..6956732 100644
--- a/oauthlib/oauth1/rfc5849/signature.py
+++ b/oauthlib/oauth1/rfc5849/signature.py
@@ -31,8 +31,15 @@ try:
 except ImportError:
     import urllib.parse as urlparse
 from . import utils
-from oauthlib.common import urldecode, extract_params, safe_string_equals
-from oauthlib.common import bytes_type, unicode_type
+
+from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives.asymmetric import padding
+from jwcrypto.jwk import JWK
+from jwcrypto.jws import JWS
+
+from oauthlib.common import (urldecode, extract_params, safe_string_equals,
+                             bytes_type, unicode_type,
+                             get_rsa_private_key, get_rsa_public_key)
 
 
 def construct_base_string(http_method, base_string_uri,
@@ -464,17 +471,8 @@ def sign_hmac_sha1(base_string, client_secret, resource_owner_secret):
     # .. _`RFC2045, Section 6.8`: http://tools.ietf.org/html/rfc2045#section-6.8
     return binascii.b2a_base64(signature.digest())[:-1].decode('utf-8')
 
-_jwtrs1 = None
-
-#jwt has some nice pycrypto/cryptography abstractions
-def _jwt_rs1_signing_algorithm():
-    global _jwtrs1
-    if _jwtrs1 is None:
-        import jwt.algorithms as jwtalgo
-        _jwtrs1 = jwtalgo.RSAAlgorithm(jwtalgo.hashes.SHA1)
-    return _jwtrs1
 
-def sign_rsa_sha1(base_string, rsa_private_key):
+def sign_rsa_sha1(base_string, private_key):
     """**RSA-SHA1**
 
     Per `section 3.4.3`_ of the spec.
@@ -493,10 +491,11 @@ def sign_rsa_sha1(base_string, rsa_private_key):
     if isinstance(base_string, unicode_type):
         base_string = base_string.encode('utf-8')
     # TODO: finish RSA documentation
-    alg = _jwt_rs1_signing_algorithm()
-    key = _prepare_key_plus(alg, rsa_private_key)
-    s=alg.sign(base_string, key)
-    return binascii.b2a_base64(s)[:-1].decode('utf-8')
+    key = get_rsa_private_key(private_key)
+    signer = key.signer(padding.PKCS1v15(), hashes.SHA1())
+    signer.update(base_string)
+    signature = signer.finalize()
+    return binascii.b2a_base64(signature)[:-1].decode('utf-8')
 
 
 def sign_rsa_sha1_with_client(base_string, client):
@@ -568,17 +567,13 @@ def verify_hmac_sha1(request, client_secret=None,
                                resource_owner_secret)
     return safe_string_equals(signature, request.signature)
 
-def _prepare_key_plus(alg, keystr):
-    if isinstance(keystr, bytes_type):
-        keystr = keystr.decode('utf-8')
-    return alg.prepare_key(keystr)
 
 def verify_rsa_sha1(request, rsa_public_key):
     """Verify a RSASSA-PKCS #1 v1.5 base64 encoded signature.
 
     Per `section 3.4.3`_ of the spec.
 
-    Note this method requires the jwt and cryptography libraries.
+    Note this method requires the cryptography library.
 
     .. _`section 3.4.3`: http://tools.ietf.org/html/rfc5849#section-3.4.3
 
@@ -595,9 +590,14 @@ def verify_rsa_sha1(request, rsa_public_key):
     message = construct_base_string(request.http_method, uri, norm_params).encode('utf-8')
     sig = binascii.a2b_base64(request.signature.encode('utf-8'))
 
-    alg = _jwt_rs1_signing_algorithm()
-    key = _prepare_key_plus(alg, rsa_public_key)
-    return alg.verify(message, key, sig)
+    key = get_rsa_public_key(rsa_public_key)
+    verifier = key.verifier(sig, padding.PKCS1v15(), hashes.SHA1())
+    verifier.update(message)
+    try:
+        verifier.verify()
+        return True
+    except InvalidSignature:
+        return False
 
 
 def verify_plaintext(request, client_secret=None, resource_owner_secret=None):
diff --git a/oauthlib/oauth2/rfc6749/clients/service_application.py b/oauthlib/oauth2/rfc6749/clients/service_application.py
index 36da98b..d974b1d 100644
--- a/oauthlib/oauth2/rfc6749/clients/service_application.py
+++ b/oauthlib/oauth2/rfc6749/clients/service_application.py
@@ -10,7 +10,7 @@ from __future__ import absolute_import, unicode_literals
 
 import time
 
-from oauthlib.common import to_unicode
+from oauthlib.common import generate_jwt_assertion
 
 from .base import Client
 from ..parameters import prepare_token_request
@@ -139,7 +139,6 @@ class ServiceApplicationClient(Client):
 
         .. _`Section 3.2.1`: http://tools.ietf.org/html/rfc6749#section-3.2.1
         """
-        import jwt
 
         key = private_key or self.private_key
         if not key:
@@ -166,8 +165,7 @@ class ServiceApplicationClient(Client):
 
         claim.update(extra_claims or {})
 
-        assertion = jwt.encode(claim, key, 'RS256')
-        assertion = to_unicode(assertion)
+        assertion = generate_jwt_assertion(key, claim)
 
         return prepare_token_request(self.grant_type,
                                      body=body,
diff --git a/setup.py b/setup.py
index 718db43..77cd6c8 100755
--- a/setup.py
+++ b/setup.py
@@ -18,11 +18,11 @@ def fread(fn):
         return f.read()
 
 if sys.version_info[0] == 3:
-    tests_require = ['nose', 'cryptography', 'pyjwt>=1.0.0', 'blinker']
+    tests_require = ['nose', 'cryptography', 'jwcrypto>=0.3.2', 'blinker']
 else:
-    tests_require = ['nose', 'unittest2', 'cryptography', 'mock', 'pyjwt>=1.0.0', 'blinker']
+    tests_require = ['nose', 'unittest2', 'cryptography', 'mock', 'jwcrypto>=0.3.2', 'blinker']
 rsa_require = ['cryptography']
-signedtoken_require = ['cryptography', 'pyjwt>=1.0.0']
+signedtoken_require = ['cryptography', 'jwcrypto>=0.3.2']
 signals_require = ['blinker']
 
 requires = []
diff --git a/tests/oauth2/rfc6749/clients/test_service_application.py b/tests/oauth2/rfc6749/clients/test_service_application.py
index de57291..842bbcb 100644
--- a/tests/oauth2/rfc6749/clients/test_service_application.py
+++ b/tests/oauth2/rfc6749/clients/test_service_application.py
@@ -4,10 +4,9 @@ from __future__ import absolute_import, unicode_literals
 import os
 from time import time
 
-import jwt
 from mock import patch
 
-from oauthlib.common import Request
+from oauthlib.common import Request, verify_signed_token
 from oauthlib.oauth2 import ServiceApplicationClient
 
 from ....unittest import TestCase
@@ -92,10 +91,10 @@ mfvGGg3xNjTMO7IdrwIDAQAB
         self.assertEqual(r.isnot, 'empty') 
         self.assertEqual(r.grant_type, ServiceApplicationClient.grant_type) 
 
-        claim = jwt.decode(r.assertion, self.public_key, audience=self.audience, algorithms=['RS256'])
+        claim = verify_signed_token(self.public_key, r.assertion)
 
         self.assertEqual(claim['iss'], self.issuer)
-        # audience verification is handled during decode now
+        self.assertEqual(claim['aud'], self.audience)
         self.assertEqual(claim['sub'], self.subject)
         self.assertEqual(claim['iat'], int(t.return_value))
 
diff --git a/tests/oauth2/rfc6749/test_server.py b/tests/oauth2/rfc6749/test_server.py
index aff0d84..850d73e 100644
--- a/tests/oauth2/rfc6749/test_server.py
+++ b/tests/oauth2/rfc6749/test_server.py
@@ -2,7 +2,6 @@
 from __future__ import absolute_import, unicode_literals
 from ...unittest import TestCase
 import json
-import jwt
 import mock
 
 from oauthlib import common