Source code

Revision control

Copy as Markdown

Other Tools

import json
import base64
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import rsa, padding
def decode_jwt(token, key=None):
    """Decode a compact JWT and verify its RS256 signature.

    If ``key`` is provided (the refresh case), it is used to check the
    signature. Otherwise (the registration case), the JWK found under the
    payload's ``"key"`` member is used instead.

    Args:
        token: The JWT as a ``"header.payload.signature"`` string.
        key: Optional JWK, as JSON text or an already-parsed dict, to verify
            the signature against.

    Returns:
        A tuple ``(decoded_header, decoded_payload, verify_succeeded)``. On any
        failure (a malformed token, a header or payload that is not a
        non-empty JSON object, an unusable key, or a bad signature) the result
        is ``(None, None, False)``.
    """
    try:
        header, payload, signature = token.split('.')
        decoded_header = decode_base64_json(header)
        decoded_payload = decode_base64_json(payload)
        # Both segments must be JSON objects. Anything else (a list, a string,
        # a number) is malformed, and so is an empty object.
        if not (isinstance(decoded_header, dict) and decoded_header):
            return None, None, False
        if not (isinstance(decoded_payload, dict) and decoded_payload):
            return None, None, False
        # Refresh: the caller supplies the key. Registration: trust the key
        # that the token itself carries.
        if key is None:
            key = decoded_payload.get('key')
        public_key = serialization.load_pem_public_key(jwk_to_pem(key))
        # Raises if the signature does not match.
        verify_rs256_signature(header, payload, signature, public_key)
        return decoded_header, decoded_payload, True
    except Exception:
        # Deliberate catch-all boundary: every failure mode (bad base64 or
        # JSON, wrong segment count, missing or invalid key, bad signature)
        # is reported through the same "not verified" result.
        return None, None, False
def jwk_to_pem(jwk_data):
    """Convert an RSA public key in JWK form to PEM-encoded bytes.

    Args:
        jwk_data: The JWK, either as serialized JSON (``str``, ``bytes`` or
            ``bytearray``) or as an already-parsed ``dict``. Only the
            ``kty``, ``n`` and ``e`` members are read.

    Returns:
        The public key serialized as a PEM SubjectPublicKeyInfo block.

    Raises:
        ValueError: If the JWK is not a JSON object or its ``kty`` is not
            ``"RSA"``. Malformed JSON raises ``json.JSONDecodeError``, a
            ``ValueError`` subclass.
        KeyError: If the ``n`` or ``e`` member is missing.
    """
    if isinstance(jwk_data, (str, bytes, bytearray)):
        jwk = json.loads(jwk_data)
    else:
        jwk = jwk_data
    # Report a missing key (e.g. None) or non-object JSON clearly, instead of
    # failing later with an AttributeError on .get().
    if not isinstance(jwk, dict):
        raise ValueError(f"JWK must be a JSON object, got {type(jwk).__name__}")
    key_type = jwk.get("kty")
    if key_type != "RSA":
        raise ValueError(f"Unsupported key type: {key_type}")
    # The modulus and exponent are base64url-encoded big-endian unsigned ints.
    modulus = int.from_bytes(decode_base64url(jwk["n"]), 'big')
    exponent = int.from_bytes(decode_base64url(jwk["e"]), 'big')
    public_key = rsa.RSAPublicNumbers(exponent, modulus).public_key()
    return public_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
def verify_rs256_signature(encoded_header, encoded_payload, signature, public_key):
    """Verify an RS256 (RSASSA-PKCS1-v1_5 with SHA-256) JWT signature.

    The signed message is the still-encoded ``header.payload`` text, encoded
    as UTF-8. The signature segment is base64url-decoded before it is checked.

    Returns nothing on success. A mismatch surfaces as the exception raised
    by ``public_key.verify`` (``cryptography.exceptions.InvalidSignature``).
    """
    signed_text = f'{encoded_header}.{encoded_payload}'
    signing_input = signed_text.encode('utf-8')
    raw_signature = decode_base64(signature)
    scheme = padding.PKCS1v15()
    digest = hashes.SHA256()
    public_key.verify(raw_signature, signing_input, scheme, digest)
def add_base64_padding(encoded_data):
    """Append '=' characters so the length of base64 text is a multiple of 4.

    Input whose length is already a multiple of 4 is returned unchanged.
    """
    # -len % 4 gives how many characters are missing up to the next multiple
    # of 4 (0 when already aligned).
    shortfall = -len(encoded_data) % 4
    if shortfall:
        encoded_data += '=' * shortfall
    return encoded_data
def decode_base64url(encoded_data):
    """Decode base64url text (URL-safe alphabet) into raw bytes.

    Missing '=' padding is restored before decoding. Characters from the
    standard alphabet ('+', '/') are accepted as well.
    """
    padded = add_base64_padding(encoded_data)
    # urlsafe_b64decode maps '-' -> '+' and '_' -> '/' and then decodes,
    # exactly like translating the characters by hand before b64decode.
    return base64.urlsafe_b64decode(padded)
def decode_base64(encoded_data):
    """Decode a base64url token segment into raw bytes.

    Despite the name, this uses the URL-safe alphabet ('-' and '_'). Missing
    '=' padding is restored before decoding.
    """
    return base64.urlsafe_b64decode(add_base64_padding(encoded_data))
def decode_base64_json(encoded_data):
    """Decode a base64url segment and parse the resulting bytes as JSON."""
    raw_json = decode_base64(encoded_data)
    return json.loads(raw_json)