# Source code for acme.jose.jws

"""JOSE Web Signature."""
import argparse
import base64
import sys

import OpenSSL
import six

from acme.jose import b64
from acme.jose import errors
from acme.jose import json_util
from acme.jose import jwa
from acme.jose import jwk
from acme.jose import util


class MediaType(object):
    """MediaType field encoder/decoder.

    Converts between a full media type (e.g. ``application/jose``) and
    the abbreviated form in which the ``application/`` prefix is
    omitted (e.g. ``jose``), following RFC 7515 sections 4.1.9/4.1.10.

    """

    PREFIX = 'application/'
    """MIME Media Type and Content Type prefix."""

    @classmethod
    def decode(cls, value):
        """Expand an abbreviated media type to its full form.

        :param value: Media type as found in a header.

        :returns: ``value`` unchanged if it contains a ``/``, otherwise
            ``value`` with :attr:`PREFIX` prepended.

        :raises errors.DeserializationError: if ``value`` has no ``/``
            but does contain a ``;``.

        """
        # 4.1.10
        if '/' not in value:
            if ';' in value:
                raise errors.DeserializationError('Unexpected semi-colon')
            return cls.PREFIX + value
        return value

    @classmethod
    def encode(cls, value):
        """Abbreviate a full media type where that is reversible.

        The :attr:`PREFIX` is stripped only when :meth:`decode` would
        restore it: the value must start with the prefix, carry no
        ``;`` parameters, and have no further ``/`` after the prefix.
        Any other value (e.g. ``text/plain``) is returned unchanged.

        :param value: Full media type.

        :returns: Abbreviated media type, or ``value`` itself.

        """
        # 4.1.10
        if ';' in value or not value.startswith(cls.PREFIX):
            return value
        abbreviated = value[len(cls.PREFIX):]
        if '/' in abbreviated:
            # decode() would not re-add the prefix, so keep the long form.
            return value
        return abbreviated
class Signature(json_util.JSONObjectWithFields):
    """JWS Signature.

    :ivar combined: Combined Header (protected and unprotected,
        :class:`Header`).
    :ivar unicode protected: JWS protected header (Jose Base-64 decoded).
    :ivar header: JWS Unprotected Header (:class:`Header`).
    :ivar str signature: The signature.

    """
    header_cls = Header

    __slots__ = ('combined',)
    protected = json_util.Field('protected', omitempty=True, default='')
    header = json_util.Field(
        'header',
        omitempty=True,
        default=header_cls(),
        decoder=header_cls.from_json,
    )
    signature = json_util.Field(
        'signature',
        decoder=json_util.decode_b64jose,
        encoder=json_util.encode_b64jose,
    )

    @protected.encoder
    def protected(value):  # pylint: disable=missing-docstring,no-self-argument
        # wrong type guess (Signature, not bytes) | pylint: disable=no-member
        utf8_bytes = value.encode('utf-8')
        return json_util.encode_b64jose(utf8_bytes)

    @protected.decoder
    def protected(value):  # pylint: disable=missing-docstring,no-self-argument
        raw = json_util.decode_b64jose(value)
        return raw.decode('utf-8')

    def __init__(self, **kwargs):
        # ``combined`` is derived from the other fields unless the
        # caller has already supplied it.
        needs_combined = 'combined' not in kwargs
        if needs_combined:
            kwargs = self._with_combined(kwargs)
        super(Signature, self).__init__(**kwargs)
        assert self.combined.alg is not None

    @classmethod
    def _with_combined(cls, kwargs):
        """Add the ``combined`` header to ``kwargs`` (in place).

        ``combined`` is the unprotected header, plus the header parsed
        from the ``protected`` JSON text when that text is non-empty.

        """
        assert 'combined' not in kwargs
        field_defs = cls._fields
        unprotected = kwargs.get('header', field_defs['header'].default)
        protected_json = kwargs.get(
            'protected', field_defs['protected'].default)

        if not protected_json:
            kwargs['combined'] = unprotected
        else:
            kwargs['combined'] = (
                unprotected + cls.header_cls.json_loads(protected_json))
        return kwargs

    @classmethod
    def _msg(cls, protected, payload):
        """Build the byte string that is signed and verified."""
        encoded_protected = b64.b64encode(protected.encode('utf-8'))
        encoded_payload = b64.b64encode(payload)
        return encoded_protected + b'.' + encoded_payload

    def verify(self, payload, key=None):
        """Verify.

        :param JWK key: Key used for verification. When ``None``, the
            key is taken from ``self.combined.find_key()``.

        """
        if key is None:
            key = self.combined.find_key()
        signing_input = self._msg(self.protected, payload)
        return self.combined.alg.verify(
            key=key.key, sig=self.signature, msg=signing_input)

    @classmethod
    def sign(cls, payload, key, alg, include_jwk=True,
             protect=frozenset(), **kwargs):
        """Sign.

        :param JWK key: Key for signature.

        Remaining keyword arguments are used as header parameters;
        those named in ``protect`` are moved to the protected header.

        """
        assert isinstance(key, alg.kty)

        params = kwargs
        params['alg'] = alg
        if include_jwk:
            params['jwk'] = key.public_key()

        known_fields = cls.header_cls._fields
        assert set(params).issubset(known_fields)
        assert protect.issubset(known_fields)

        # Move every protected parameter out of the unprotected set.
        moved = dict((name, params.pop(name)) for name in protect)
        if moved:
            # pylint: disable=star-args
            protected = cls.header_cls(**moved).json_dumps()
        else:
            protected = ''

        unprotected = cls.header_cls(**params)  # pylint: disable=star-args
        raw_signature = alg.sign(key.key, cls._msg(protected, payload))

        return cls(
            protected=protected, header=unprotected, signature=raw_signature)

    def fields_to_partial_json(self):
        partial = super(Signature, self).fields_to_partial_json()
        # Drop the unprotected header when none of its fields are set.
        if not partial['header'].not_omitted():
            del partial['header']
        return partial

    @classmethod
    def fields_from_json(cls, jobj):
        parsed = super(Signature, cls).fields_from_json(jobj)
        completed = cls._with_combined(parsed)
        if 'alg' in completed['combined'].not_omitted():
            return completed
        raise errors.DeserializationError('alg not present')
class JWS(json_util.JSONObjectWithFields):
    """JSON Web Signature.

    :ivar str payload: JWS Payload.
    :ivar signatures: JWS Signatures (sequence of `signature_cls`).

    """
    __slots__ = ('payload', 'signatures')

    signature_cls = Signature

    def verify(self, key=None):
        """Verify.

        :param key: Passed through to each signature's ``verify``.

        :returns: ``True`` only if every signature verifies the payload.

        """
        return all(sig.verify(self.payload, key) for sig in self.signatures)

    @classmethod
    def sign(cls, payload, **kwargs):
        """Sign.

        Creates a JWS holding a single signature over ``payload``;
        ``kwargs`` are forwarded to ``signature_cls.sign``.

        """
        return cls(payload=payload, signatures=(
            cls.signature_cls.sign(payload=payload, **kwargs),))

    @property
    def signature(self):
        """Get a singleton signature.

        :rtype: `signature_cls`

        """
        assert len(self.signatures) == 1
        return self.signatures[0]

    def to_compact(self):
        """Compact serialization.

        Only valid for exactly one signature whose ``alg`` is not in
        the unprotected header.

        :rtype: bytes

        """
        assert len(self.signatures) == 1

        assert 'alg' not in self.signature.header.not_omitted()
        # ... it must be in protected

        return (
            b64.b64encode(self.signature.protected.encode('utf-8')) +
            b'.' +
            b64.b64encode(self.payload) +
            b'.' +
            b64.b64encode(self.signature.signature))

    @classmethod
    def from_compact(cls, compact):
        """Compact deserialization.

        :param bytes compact: Three dot-separated components
            (protected header, payload, signature).

        :raises errors.DeserializationError: if ``compact`` does not
            split into exactly three components.

        """
        try:
            protected, payload, signature = compact.split(b'.')
        except ValueError:
            raise errors.DeserializationError(
                'Compact JWS serialization should comprise of exactly'
                ' 3 dot-separated components')

        sig = cls.signature_cls(
            protected=b64.b64decode(protected).decode('utf-8'),
            signature=b64.b64decode(signature))
        return cls(payload=b64.b64decode(payload), signatures=(sig,))

    def to_partial_json(self, flat=True):  # pylint: disable=arguments-differ
        """Serialize to the flat or the general JSON form.

        :param bool flat: Use the flat form when there is exactly one
            signature; otherwise the form with a ``signatures`` list.

        """
        assert self.signatures
        payload = json_util.encode_b64jose(self.payload)

        if flat and len(self.signatures) == 1:
            ret = self.signatures[0].to_partial_json()
            ret['payload'] = payload
            return ret
        else:
            return {
                'payload': payload,
                'signatures': self.signatures,
            }

    @classmethod
    def from_json(cls, jobj):
        """Deserialize from the flat or the general JSON form.

        The input mapping is not modified.

        :raises errors.DeserializationError: if both ``signature`` and
            ``signatures`` are present.

        """
        if 'signature' in jobj and 'signatures' in jobj:
            raise errors.DeserializationError('Flat mixed with non-flat')
        elif 'signature' in jobj:  # flat
            # Work on a shallow copy so that removing ``payload`` does
            # not alter the caller's object.
            flat_jobj = dict(jobj)
            payload = json_util.decode_b64jose(flat_jobj.pop('payload'))
            return cls(payload=payload,
                       signatures=(cls.signature_cls.from_json(flat_jobj),))
        else:
            return cls(payload=json_util.decode_b64jose(jobj['payload']),
                       signatures=tuple(cls.signature_cls.from_json(sig)
                                        for sig in jobj['signatures']))
class CLI(object):
    """JWS CLI."""

    @classmethod
    def sign(cls, args):
        """Sign.

        Loads the key from ``args.key``, signs the payload read from
        stdin and prints the JWS (compact form with ``--compact``,
        pretty-printed JSON otherwise).

        :param args: Parsed arguments of the ``sign`` sub-command.

        """
        try:
            key = args.alg.kty.load(args.key.read())
        finally:
            # Release the key file even if loading the key fails.
            args.key.close()

        if args.protect is None:
            args.protect = []
        if args.compact:
            # to_compact() requires ``alg`` in the protected header.
            args.protect.append('alg')

        sig = JWS.sign(payload=sys.stdin.read().encode(), key=key,
                       alg=args.alg, protect=set(args.protect))

        if args.compact:
            six.print_(sig.to_compact().decode('utf-8'))
        else:  # JSON
            six.print_(sig.json_dumps_pretty())

    @classmethod
    def verify(cls, args):
        """Verify.

        Reads a JWS from stdin, writes its payload to stdout and
        verifies the signature(s).

        :param args: Parsed arguments of the ``verify`` sub-command.

        :returns: ``False`` if verification succeeded, ``True`` if it
            failed, ``-1`` if the input raised ``errors.Error`` while
            being deserialized.

        """
        try:
            if args.compact:
                sig = JWS.from_compact(sys.stdin.read().encode())
            else:  # JSON
                sig = JWS.json_loads(sys.stdin.read())
        except errors.Error as error:
            # Both input formats report deserialization problems alike.
            six.print_(error)
            return -1

        if args.key is not None:
            assert args.kty is not None
            try:
                key = args.kty.load(args.key.read()).public_key()
            finally:
                args.key.close()
        else:
            key = None

        # The payload is bytes: use the binary layer of stdout where one
        # exists (Python 3); otherwise write to stdout directly.
        getattr(sys.stdout, 'buffer', sys.stdout).write(sig.payload)
        return not sig.verify(key=key)

    @classmethod
    def _alg_type(cls, arg):
        """argparse ``type`` for ``--alg``."""
        return jwa.JWASignature.from_json(arg)

    @classmethod
    def _header_type(cls, arg):
        """argparse ``type`` for ``--protect``: a known header field name."""
        if arg not in Signature.header_cls._fields:
            raise argparse.ArgumentTypeError(
                'unknown header field: %r' % (arg,))
        return arg

    @classmethod
    def _kty_type(cls, arg):
        """argparse ``type`` for ``--kty``: maps a name to its JWK type."""
        if arg not in jwk.JWK.TYPES:
            raise argparse.ArgumentTypeError('unknown key type: %r' % (arg,))
        return jwk.JWK.TYPES[arg]

    @classmethod
    def run(cls, args=None):
        """Parse arguments and sign/verify.

        :param list args: Command line arguments; defaults to
            ``sys.argv[1:]`` as it is at call time.

        :returns: Return value of the selected sub-command.

        """
        if args is None:
            args = sys.argv[1:]

        parser = argparse.ArgumentParser()
        parser.add_argument('--compact', action='store_true')

        subparsers = parser.add_subparsers(dest='command')
        # Without this, Python 3 accepts a missing sub-command and the
        # ``parsed.func`` lookup below fails with AttributeError.
        subparsers.required = True

        parser_sign = subparsers.add_parser('sign')
        parser_sign.set_defaults(func=cls.sign)
        parser_sign.add_argument(
            '-k', '--key', type=argparse.FileType('rb'), required=True)
        parser_sign.add_argument(
            '-a', '--alg', type=cls._alg_type, default=jwa.RS256)
        parser_sign.add_argument(
            '-p', '--protect', action='append', type=cls._header_type)

        parser_verify = subparsers.add_parser('verify')
        parser_verify.set_defaults(func=cls.verify)
        parser_verify.add_argument(
            '-k', '--key', type=argparse.FileType('rb'), required=False)
        parser_verify.add_argument(
            '--kty', type=cls._kty_type, required=False)

        parsed = parser.parse_args(args)
        return parsed.func(parsed)
if __name__ == '__main__':
    # sys.exit rather than the site-provided ``exit`` helper, which is
    # not available when the interpreter runs without the site module.
    sys.exit(CLI.run())  # pragma: no cover