2018-03-13 06:35:28 +06:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
2020-06-16 07:19:15 +06:00
# Pascal implementation by lulzkabulz. Python translation by apprenticenaomi. DeDRM integration by anon. VoucherEnvelope v2/v3 support by apprenticesakuya.
2018-03-13 06:35:28 +06:00
# BinaryIon.pas + DrmIon.pas + IonSymbols.pas
from __future__ import with_statement
import collections
import hashlib
import hmac
import os
import os . path
import struct
try :
from cStringIO import StringIO
except ImportError :
from StringIO import StringIO
from Crypto . Cipher import AES
from Crypto . Util . py3compat import bchr , bord
try :
2020-01-04 19:20:16 +06:00
# lzma library from calibre 4.6.0 or later
import calibre_lzma . lzma1 as calibre_lzma
2018-04-18 14:57:07 +06:00
except ImportError :
2018-03-13 06:35:28 +06:00
calibre_lzma = None
2020-01-04 19:20:16 +06:00
# lzma library from calibre 2.35.0 or later
2018-03-13 06:35:28 +06:00
try :
2020-01-04 19:20:16 +06:00
import lzma . lzma1 as calibre_lzma
2018-04-18 14:57:07 +06:00
except ImportError :
2020-01-04 19:20:16 +06:00
calibre_lzma = None
2018-04-18 14:57:07 +06:00
try :
2020-01-04 19:20:16 +06:00
import lzma
2018-04-18 14:57:07 +06:00
except ImportError :
2020-01-04 19:20:16 +06:00
# Need pip backports.lzma on Python <3.3
try :
from backports import lzma
except ImportError :
# Windows-friendly choice: pylzma wheels
import pylzma as lzma
2018-03-13 06:35:28 +06:00
# Ion type IDs.  BinaryIonParser.readtypeid() obtains one of these from the
# upper four bits of a value's leading byte.
(TID_NULL, TID_BOOLEAN, TID_POSINT, TID_NEGINT,
 TID_FLOAT, TID_DECIMAL, TID_TIMESTAMP, TID_SYMBOL,
 TID_STRING, TID_CLOB, TID_BLOB, TID_LIST,
 TID_SEXP, TID_STRUCT, TID_TYPEDECL, TID_UNUSED) = range(0x10)

# Symbol IDs of the Ion system symbol table.  SID_UNKNOWN doubles as the
# "no field id" marker used by the parser.
SID_UNKNOWN = -1
(SID_ION, SID_ION_1_0, SID_ION_SYMBOL_TABLE, SID_NAME, SID_VERSION,
 SID_IMPORTS, SID_SYMBOLS, SID_MAX_ID, SID_ION_SHARED_SYMBOL_TABLE,
 SID_ION_1_0_MAX) = range(1, 11)

# Special values of the length nibble (lower four bits of the leading byte).
LEN_IS_VAR_LEN = 0xE
LEN_IS_NULL = 0xF

# Bytes expected after a zero-length type declaration byte
# (checked by BinaryIonParser.checkversionmarker()).
VERSION_MARKER = b"\x01\x00\xEA"
# asserts must always raise exceptions for proper functioning
def _assert ( test , msg = " Exception " ) :
if not test :
raise Exception ( msg )
class SystemSymbols(object):
    """Text of the Ion system symbols, in the order of the SID_* constants."""

    ION = '$ion'
    ION_1_0 = '$ion_1_0'
    ION_SYMBOL_TABLE = '$ion_symbol_table'

    # Field names that appear inside symbol table structs.
    NAME = 'name'
    VERSION = 'version'
    IMPORTS = 'imports'
    SYMBOLS = 'symbols'
    MAX_ID = 'max_id'

    ION_SHARED_SYMBOL_TABLE = '$ion_shared_symbol_table'
class IonCatalogItem(object):
    """One shared symbol table known to a parser: name, version and symbol texts."""

    # Class-level defaults; every instance overrides them in __init__.
    name = ""
    version = 0
    symnames = []

    def __init__(self, name, version, symnames):
        """Store the table *name*, its *version* and the list of symbol names."""
        self.name, self.version, self.symnames = name, version, symnames
class SymbolToken(object):
    """A symbol reference made of its text and/or its numeric symbol id."""

    text = ""
    sid = 0

    def __init__(self, text, sid):
        """Create a token.

        Raises ValueError when both *text* is empty and *sid* is 0, i.e. the
        token would identify nothing.
        """
        has_nothing = (text == "" and sid == 0)
        if has_nothing:
            raise ValueError("Symbol token must have Text or SID")
        self.text, self.sid = text, sid
class SymbolTable(object):
    """Symbol id -> text lookup table, pre-seeded with the Ion system symbols."""

    table = None

    def __init__(self):
        """Create a table holding only the system symbols (index 0 stays None)."""
        self.table = [None] * SID_ION_1_0_MAX
        system_entries = (
            (SID_ION, SystemSymbols.ION),
            (SID_ION_1_0, SystemSymbols.ION_1_0),
            (SID_ION_SYMBOL_TABLE, SystemSymbols.ION_SYMBOL_TABLE),
            (SID_NAME, SystemSymbols.NAME),
            (SID_VERSION, SystemSymbols.VERSION),
            (SID_IMPORTS, SystemSymbols.IMPORTS),
            (SID_SYMBOLS, SystemSymbols.SYMBOLS),
            (SID_MAX_ID, SystemSymbols.MAX_ID),
            (SID_ION_SHARED_SYMBOL_TABLE, SystemSymbols.ION_SHARED_SYMBOL_TABLE),
        )
        for sid, text in system_entries:
            self.table[sid] = text

    def findbyid(self, sid):
        """Return the text for *sid*, or "" when *sid* is beyond the table.

        Raises ValueError for ids below 1.
        """
        if sid < 1:
            raise ValueError("Invalid symbol id")
        if sid >= len(self.table):
            return ""
        return self.table[sid]

    def import_(self, table, maxid):
        """Append the first *maxid* symbol names of catalog item *table*."""
        names = table.symnames
        for position in range(maxid):
            self.table.append(names[position])

    def importunknown(self, name, maxid):
        """Append *maxid* placeholder symbols ("name#1", "name#2", ...) for a table that is not in the catalog."""
        for number in range(1, maxid + 1):
            self.table.append("%s#%d" % (name, number))
class ParserState:
    """Positions the parser can be in relative to the value being read."""
    Invalid = 1
    BeforeField = 2
    BeforeTID = 3
    BeforeValue = 4
    AfterValue = 5
    EOF = 6
# Saved parser context for one level of container nesting: where the stream
# must be positioned after the container, the type id to restore as the
# parent, and the byte budget to restore.
ContainerRec = collections.namedtuple("ContainerRec", ["nextpos", "tid", "remaining"])
class BinaryIonParser(object):
    """Cursor-style reader for a binary Ion stream.

    Values are visited with hasnext()/next(), containers are entered and left
    with stepin()/stepout(), and the current value is decoded on demand by
    the ``*value()`` accessors.

    ``localremaining`` is the number of bytes still available in the
    container being read; -1 means "no limit" (used at the top level).
    """
    eof = False
    state = None
    localremaining = 0
    needhasnext = False
    isinstruct = False
    valuetid = 0
    valuefieldid = 0
    parenttid = 0
    valuelen = 0
    valueisnull = False
    valueistrue = False
    value = None
    didimports = False

    def __init__(self, stream):
        """Wrap *stream*, a seekable binary file object; parsing starts at its current position."""
        self.annotations = []
        self.catalog = []
        self.stream = stream
        self.initpos = stream.tell()
        self.reset()
        self.symbols = SymbolTable()

    def reset(self):
        """Rewind to the initial stream position and clear the traversal state.

        The symbol table and catalog are left untouched.
        """
        self.state = ParserState.BeforeTID
        self.needhasnext = True
        self.localremaining = -1
        self.eof = False
        self.isinstruct = False
        self.containerstack = []
        self.stream.seek(self.initpos)

    def addtocatalog(self, name, version, symbols):
        """Register a shared symbol table that imports may refer to by *name*."""
        self.catalog.append(IonCatalogItem(name, version, symbols))

    def hasnext(self):
        """Position on the next value if necessary; return False at end of data.

        At the top level, version markers and structs annotated as symbol
        tables are consumed here and never surface as values.
        """
        while self.needhasnext and not self.eof:
            self.hasnextraw()
            if len(self.containerstack) == 0 and not self.valueisnull:
                if self.valuetid == TID_SYMBOL:
                    if self.value == SID_ION_1_0:
                        self.needhasnext = True
                elif self.valuetid == TID_STRUCT:
                    for a in self.annotations:
                        if a == SID_ION_SYMBOL_TABLE:
                            self.parsesymboltable()
                            self.needhasnext = True
                            break
        return not self.eof

    def hasnextraw(self):
        """Run the state machine until a value type id has been read or EOF is hit."""
        self.clearvalue()
        while self.valuetid == -1 and not self.eof:
            self.needhasnext = False
            if self.state == ParserState.BeforeField:
                _assert(self.valuefieldid == SID_UNKNOWN)

                self.valuefieldid = self.readfieldid()
                if self.valuefieldid != SID_UNKNOWN:
                    self.state = ParserState.BeforeTID
                else:
                    self.eof = True

            elif self.state == ParserState.BeforeTID:
                self.state = ParserState.BeforeValue
                self.valuetid = self.readtypeid()
                if self.valuetid == -1:
                    self.state = ParserState.EOF
                    self.eof = True
                    break

                if self.valuetid == TID_TYPEDECL:
                    # Zero length: must be a version marker; otherwise an
                    # annotation wrapper around the real value.
                    if self.valuelen == 0:
                        self.checkversionmarker()
                    else:
                        self.loadannotations()

            elif self.state == ParserState.BeforeValue:
                # The previous value was never decoded: jump over its bytes.
                self.skip(self.valuelen)
                self.state = ParserState.AfterValue

            elif self.state == ParserState.AfterValue:
                if self.isinstruct:
                    self.state = ParserState.BeforeField
                else:
                    self.state = ParserState.BeforeTID

            else:
                _assert(self.state == ParserState.EOF)

    def next(self):
        """Advance to the next value and return its type id, or -1 when there is none."""
        if self.hasnext():
            self.needhasnext = True
            return self.valuetid
        else:
            return -1

    def push(self, typeid, nextposition, nextremaining):
        """Save the enclosing container's context on the container stack."""
        self.containerstack.append(ContainerRec(nextpos=nextposition, tid=typeid, remaining=nextremaining))

    def stepin(self):
        """Enter the current struct, list or s-expression.

        Raises Exception when the current value is not a container or the
        parser is not positioned at the start of its content.
        """
        _assert(self.valuetid in [TID_STRUCT, TID_LIST, TID_SEXP] and not self.eof,
                "valuetid=%s eof=%s" % (self.valuetid, self.eof))
        _assert((not self.valueisnull or self.state == ParserState.AfterValue) and
                (self.valueisnull or self.state == ParserState.BeforeValue))

        # Byte budget the parent will have left once this container is done.
        nextrem = self.localremaining
        if nextrem != -1:
            nextrem -= self.valuelen
            if nextrem < 0:
                nextrem = 0
        self.push(self.parenttid, self.stream.tell() + self.valuelen, nextrem)

        self.isinstruct = (self.valuetid == TID_STRUCT)
        if self.isinstruct:
            self.state = ParserState.BeforeField
        else:
            self.state = ParserState.BeforeTID

        self.localremaining = self.valuelen
        self.parenttid = self.valuetid
        self.clearvalue()
        self.needhasnext = True

    def stepout(self):
        """Leave the current container, skipping whatever is left of it."""
        rec = self.containerstack.pop()

        self.eof = False
        self.parenttid = rec.tid
        if self.parenttid == TID_STRUCT:
            self.isinstruct = True
            self.state = ParserState.BeforeField
        else:
            self.isinstruct = False
            self.state = ParserState.BeforeTID
        self.needhasnext = True

        self.clearvalue()
        curpos = self.stream.tell()
        if rec.nextpos > curpos:
            self.skip(rec.nextpos - curpos)
        else:
            _assert(rec.nextpos == curpos)

        self.localremaining = rec.remaining

    def read(self, count=1):
        """Read exactly *count* bytes, charging them to the container budget.

        Raises Exception when the read would run past the current container
        and EOFError when the stream ends before *count* bytes are available.
        """
        if self.localremaining != -1:
            self.localremaining -= count
            _assert(self.localremaining >= 0)

        result = self.stream.read(count)
        # Compare against count rather than 0: a zero-length read (empty
        # string or blob) is legitimate, while a short read means truncation.
        if len(result) < count:
            raise EOFError()
        return result

    def readfieldid(self):
        """Read a struct field's symbol id; return -1 at the end of the struct or stream."""
        if self.localremaining != -1 and self.localremaining < 1:
            return -1

        try:
            return self.readvaruint()
        except EOFError:
            return -1

    def readtypeid(self):
        """Read a type descriptor byte; store the value length in ``valuelen``.

        Returns the type id (upper nibble), or -1 when no byte is available.
        """
        if self.localremaining != -1:
            if self.localremaining < 1:
                return -1
            self.localremaining -= 1

        b = self.stream.read(1)
        if len(b) < 1:
            return -1
        b = bord(b)
        result = b >> 4
        ln = b & 0xF

        if ln == LEN_IS_VAR_LEN:
            ln = self.readvaruint()
        elif ln == LEN_IS_NULL:
            ln = 0
            self.state = ParserState.AfterValue
        elif result == TID_NULL:
            # Must have LEN_IS_NULL
            _assert(False)
        elif result == TID_BOOLEAN:
            # Booleans carry their value in the length nibble.
            _assert(ln <= 1)
            self.valueistrue = (ln == 1)
            ln = 0
            self.state = ParserState.AfterValue
        elif result == TID_STRUCT:
            if ln == 1:
                ln = self.readvaruint()

        self.valuelen = ln
        return result

    def readvarint(self):
        """Read a signed variable-length integer (at most five bytes).

        Bit 0x40 of the first byte is the sign; a byte with bit 0x80 set
        terminates the number.  Raises Exception("int overflow") otherwise.
        """
        b = bord(self.read())
        negative = ((b & 0x40) != 0)
        result = (b & 0x3F)

        i = 0
        while (b & 0x80) == 0 and i < 4:
            b = bord(self.read())
            result = (result << 7) | (b & 0x7F)
            i += 1

        _assert(i < 4 or (b & 0x80) != 0, "int overflow")

        if negative:
            return -result
        return result

    def readvaruint(self):
        """Read an unsigned variable-length integer (at most five bytes).

        A byte with bit 0x80 set terminates the number.  Raises
        Exception("int overflow") when no terminator is found in time.
        """
        b = bord(self.read())
        result = (b & 0x7F)

        i = 0
        while (b & 0x80) == 0 and i < 4:
            b = bord(self.read())
            result = (result << 7) | (b & 0x7F)
            i += 1

        _assert(i < 4 or (b & 0x80) != 0, "int overflow")

        return result

    def readdecimal(self):
        """Decode the current decimal value as coefficient * 10 ** exponent.

        Returns 0.0 for a zero-length value.  Raises Exception when the
        coefficient is missing or longer than eight bytes.
        """
        if self.valuelen == 0:
            return 0.

        # Budget to restore afterwards.  -1 is the "unbounded" sentinel and
        # must be preserved as is, not used in arithmetic.
        if self.localremaining == -1:
            rem = -1
        else:
            rem = self.localremaining - self.valuelen
        self.localremaining = self.valuelen
        exponent = self.readvarint()

        _assert(self.localremaining > 0, "Only exponent in ReadDecimal")
        _assert(self.localremaining <= 8, "Decimal overflow")

        signed = False
        b = bytearray(self.read(self.localremaining))
        if (b[0] & 0x80) != 0:
            b[0] = b[0] & 0x7F
            signed = True

        # The coefficient is a variably sized network-order (big-endian)
        # magnitude; fold it into an integer byte by byte.
        v = 0
        for byte in b:
            v = (v << 8) | byte

        result = v * (10 ** exponent)
        if signed:
            result = -result

        self.localremaining = rem
        return result

    def skip(self, count):
        """Seek *count* bytes forward; raise EOFError if that leaves the current container."""
        if self.localremaining != -1:
            self.localremaining -= count
            if self.localremaining < 0:
                raise EOFError()

        self.stream.seek(count, os.SEEK_CUR)

    def parsesymboltable(self):
        """Process a local symbol table struct; only its imports list is supported.

        The imports are applied once per parser (guarded by ``didimports``).
        """
        self.next()  # shouldn't do anything?

        _assert(self.valuetid == TID_STRUCT)

        if self.didimports:
            return

        self.stepin()

        fieldtype = self.next()
        while fieldtype != -1:
            if not self.valueisnull:
                _assert(self.valuefieldid == SID_IMPORTS, "Unsupported symbol table field id")

                if fieldtype == TID_LIST:
                    self.gatherimports()

            fieldtype = self.next()

        self.stepout()
        self.didimports = True

    def gatherimports(self):
        """Walk the imports list and process every struct entry in it."""
        self.stepin()

        t = self.next()
        while t != -1:
            if not self.valueisnull and t == TID_STRUCT:
                self.readimport()

            t = self.next()

        self.stepout()

    def readimport(self):
        """Read one import struct (name, version, max_id) and extend the symbol table.

        Tables found in the catalog contribute their real symbol names;
        unknown tables contribute placeholder names.
        """
        version = -1
        maxid = -1
        name = ""

        self.stepin()

        t = self.next()
        while t != -1:
            if not self.valueisnull and self.valuefieldid != SID_UNKNOWN:
                if self.valuefieldid == SID_NAME:
                    name = self.stringvalue()
                elif self.valuefieldid == SID_VERSION:
                    version = self.intvalue()
                elif self.valuefieldid == SID_MAX_ID:
                    maxid = self.intvalue()

            t = self.next()

        self.stepout()

        if name == "" or name == SystemSymbols.ION:
            return

        if version < 1:
            version = 1

        table = self.findcatalogitem(name)
        if maxid < 0:
            # Without max_id the size can only come from an exact catalog match.
            _assert(table is not None and version == table.version, "Import %s lacks maxid" % name)
            maxid = len(table.symnames)

        if table is not None:
            self.symbols.import_(table, min(maxid, len(table.symnames)))
        else:
            self.symbols.importunknown(name, maxid)

    def intvalue(self):
        """Return the current positive or negative int value."""
        _assert(self.valuetid in [TID_POSINT, TID_NEGINT], "Not an int")

        self.preparevalue()
        return self.value

    def stringvalue(self):
        """Return the current string value decoded as UTF-8 ("" when null)."""
        _assert(self.valuetid == TID_STRING, "Not a string")

        if self.valueisnull:
            return ""

        self.preparevalue()
        return self.value

    def symbolvalue(self):
        """Return the text of the current symbol, or "SYMBOL#<id>" when it is unknown."""
        _assert(self.valuetid == TID_SYMBOL, "Not a symbol")

        self.preparevalue()
        result = self.symbols.findbyid(self.value)
        if result == "":
            result = "SYMBOL#%d" % self.value
        return result

    def lobvalue(self):
        """Return the raw bytes of the current clob/blob (None when null)."""
        _assert(self.valuetid in [TID_CLOB, TID_BLOB], "Not a LOB type: %s" % self.getfieldname())

        if self.valueisnull:
            return None

        result = self.read(self.valuelen)
        self.state = ParserState.AfterValue
        return result

    def decimalvalue(self):
        """Return the current decimal value."""
        _assert(self.valuetid == TID_DECIMAL, "Not a decimal")

        self.preparevalue()
        return self.value

    def preparevalue(self):
        """Decode the current scalar unless it has been decoded already."""
        if self.value is None:
            self.loadscalarvalue()

    def loadscalarvalue(self):
        """Decode the current scalar into ``self.value``.

        Only strings, ints, symbols and decimals are decoded; other scalar
        types leave ``self.value`` as None.  Non-scalars are ignored.
        """
        if self.valuetid not in [TID_NULL, TID_BOOLEAN, TID_POSINT, TID_NEGINT,
                                 TID_FLOAT, TID_DECIMAL, TID_TIMESTAMP,
                                 TID_SYMBOL, TID_STRING]:
            return

        if self.valueisnull:
            self.value = None
            return

        if self.valuetid == TID_STRING:
            self.value = self.read(self.valuelen).decode("UTF-8")

        elif self.valuetid in (TID_POSINT, TID_NEGINT, TID_SYMBOL):
            if self.valuelen == 0:
                self.value = 0
            else:
                _assert(self.valuelen <= 4, "int too long: %d" % self.valuelen)
                # Big-endian magnitude: the first byte read is the most significant.
                v = 0
                for i in range(self.valuelen - 1, -1, -1):
                    v = (v | (bord(self.read()) << (i * 8)))

                if self.valuetid == TID_NEGINT:
                    self.value = -v
                else:
                    self.value = v

        elif self.valuetid == TID_DECIMAL:
            self.value = self.readdecimal()

        self.state = ParserState.AfterValue

    def clearvalue(self):
        """Forget everything about the current value."""
        self.valuetid = -1
        self.value = None
        self.valueisnull = False
        self.valuefieldid = SID_UNKNOWN
        self.annotations = []

    def loadannotations(self):
        """Read the annotation symbol ids of a wrapper, then the wrapped value's type id."""
        ln = self.readvaruint()
        maxpos = self.stream.tell() + ln
        while self.stream.tell() < maxpos:
            self.annotations.append(self.readvaruint())
        self.valuetid = self.readtypeid()

    def checkversionmarker(self):
        """Consume the rest of a version marker and present it as the $ion_1_0 symbol.

        Raises Exception("Unknown version marker") on any mismatch.
        """
        # Compare one-byte slices so both operands are byte strings.
        for offset in range(len(VERSION_MARKER)):
            _assert(self.read() == VERSION_MARKER[offset:offset + 1], "Unknown version marker")

        self.valuelen = 0
        self.valuetid = TID_SYMBOL
        self.value = SID_ION_1_0
        self.valueisnull = False
        self.valuefieldid = SID_UNKNOWN
        self.state = ParserState.AfterValue

    def findcatalogitem(self, name):
        """Return the first catalog item called *name*, or None."""
        for result in self.catalog:
            if result.name == name:
                return result

    def forceimport(self, symbols):
        """Append *symbols* to the symbol table unconditionally."""
        item = IonCatalogItem("Forced", 1, symbols)
        self.symbols.import_(item, len(symbols))

    def getfieldname(self):
        """Return the field name of the current value ("" outside a struct field)."""
        if self.valuefieldid == SID_UNKNOWN:
            return ""
        return self.symbols.findbyid(self.valuefieldid)

    def getfieldnamesymbol(self):
        """Return the current field name and id as a SymbolToken."""
        return SymbolToken(self.getfieldname(), self.valuefieldid)

    def gettypename(self):
        """Return the text of the first annotation of the current value ("" if none)."""
        if len(self.annotations) == 0:
            return ""

        return self.symbols.findbyid(self.annotations[0])

    @staticmethod
    def printlob(b):
        """Format byte string *b* as space-separated hex pairs ("null" for None)."""
        if b is None:
            return "null"

        return " ".join("%02x" % bord(i) for i in b)

    def ionwalk(self, supert, indent, lst):
        """Append a text dump of the values at the current level to *lst*.

        *supert* is the type id of the enclosing container (-1 at top level)
        and *indent* the prefix for every emitted line.  Structs and lists
        are walked recursively.
        """
        while self.hasnext():
            if supert == TID_STRUCT:
                L = self.getfieldname() + ":"
            else:
                L = ""

            t = self.next()
            if t in [TID_STRUCT, TID_LIST]:
                if L != "":
                    lst.append(indent + L)
                L = self.gettypename()
                if L != "":
                    lst.append(indent + L + "::")
                if t == TID_STRUCT:
                    lst.append(indent + "{")
                else:
                    lst.append(indent + "[")

                self.stepin()
                self.ionwalk(t, indent + "  ", lst)
                self.stepout()

                if t == TID_STRUCT:
                    lst.append(indent + "}")
                else:
                    lst.append(indent + "]")

            else:
                if t == TID_STRING:
                    L += ('"%s"' % self.stringvalue())
                elif t in [TID_CLOB, TID_BLOB]:
                    L += ("{%s}" % self.printlob(self.lobvalue()))
                elif t == TID_POSINT:
                    L += str(self.intvalue())
                elif t == TID_SYMBOL:
                    tn = self.gettypename()
                    if tn != "":
                        tn += "::"
                    L += tn + self.symbolvalue()
                elif t == TID_DECIMAL:
                    L += str(self.decimalvalue())
                else:
                    L += ("TID %d" % t)
                lst.append(indent + L)

    def print_(self, lst):
        """Rewind and append a text dump of the whole stream to *lst*."""
        self.reset()
        self.ionwalk(-1, "", lst)
# Symbol names of the "ProtectedData" shared symbol table, in symbol id order.
# The order is significant: the parser maps ids to names by position.
SYM_NAMES = [
    'com.amazon.drm.Envelope@1.0',
    'com.amazon.drm.EnvelopeMetadata@1.0',
    'size',
    'page_size',
    'encryption_key',
    'encryption_transformation',
    'encryption_voucher',
    'signing_key',
    'signing_algorithm',
    'signing_voucher',
    'com.amazon.drm.EncryptedPage@1.0',
    'cipher_text',
    'cipher_iv',
    'com.amazon.drm.Signature@1.0',
    'data',
    'com.amazon.drm.EnvelopeIndexTable@1.0',
    'length',
    'offset',
    'algorithm',
    'encoded',
    'encryption_algorithm',
    'hashing_algorithm',
    'expires',
    'format',
    'id',
    'lock_parameters',
    'strategy',
    'com.amazon.drm.Key@1.0',
    'com.amazon.drm.KeySet@1.0',
    'com.amazon.drm.PIDv3@1.0',
    'com.amazon.drm.PlainTextPage@1.0',
    'com.amazon.drm.PlainText@1.0',
    'com.amazon.drm.PrivateKey@1.0',
    'com.amazon.drm.PublicKey@1.0',
    'com.amazon.drm.SecretKey@1.0',
    'com.amazon.drm.Voucher@1.0',
    'public_key',
    'private_key',
    'com.amazon.drm.KeyPair@1.0',
    'com.amazon.drm.ProtectedData@1.0',
    'doctype',
    'com.amazon.drm.EnvelopeIndexTableOffset@1.0',
    'enddoc',
    'license_type',
    'license',
    'watermark',
    'key',
    'value',
    'com.amazon.drm.License@1.0',
    'category',
    'metadata',
    'categorized_metadata',
    'com.amazon.drm.CategorizedMetadata@1.0',
    'com.amazon.drm.VoucherEnvelope@1.0',
    'mac',
    'voucher',
    'com.amazon.drm.ProtectedData@2.0',
    'com.amazon.drm.Envelope@2.0',
    'com.amazon.drm.EnvelopeMetadata@2.0',
    'com.amazon.drm.EncryptedPage@2.0',
    'com.amazon.drm.PlainText@2.0',
    'compression_algorithm',
    'com.amazon.drm.Compressed@1.0',
    'page_index_table',
    'com.amazon.drm.VoucherEnvelope@2.0',
    'com.amazon.drm.VoucherEnvelope@3.0',
]
2018-03-13 06:35:28 +06:00
def addprottable(ion):
    """Register the "ProtectedData" shared symbol table (version 1, SYM_NAMES) with parser *ion*."""
    table_name = "ProtectedData"
    table_version = 1
    ion.addtocatalog(table_name, table_version, SYM_NAMES)
def pkcs7pad(msg, blocklen):
    """Return *msg* padded to a multiple of *blocklen* bytes.

    Between 1 and *blocklen* bytes are appended, each holding the number of
    bytes appended; an already aligned message gains a full block.
    """
    used = len(msg) % blocklen
    fill = blocklen - used
    return msg + bchr(fill) * fill
def pkcs7unpad(msg, blocklen):
    """Strip the padding added by pkcs7pad() from *msg* and return the rest.

    Raises Exception when *msg* is empty, is not a whole number of blocks,
    or does not end in valid padding.  After a decryption, invalid padding
    typically means the wrong key was used, hence the message text.
    """
    _assert(len(msg) % blocklen == 0)
    # An empty message passes the modulus check but has no padding byte to
    # inspect; fail with the module's own exception instead of IndexError.
    _assert(len(msg) > 0, "Incorrect padding - Empty message")

    paddinglen = bord(msg[-1])
    _assert(paddinglen > 0 and paddinglen <= blocklen, "Incorrect padding - Wrong key")
    _assert(msg[-paddinglen:] == bchr(paddinglen) * paddinglen, "Incorrect padding - Wrong key")

    return msg[:-paddinglen]
2020-06-16 07:19:15 +06:00
# every VoucherEnvelope version has a corresponding "word" and magic number, used in obfuscating the shared secret
VOUCHER_VERSION_INFOS = {
    2: [b'Antidisestablishmentarianism', 5],
    3: [b'Floccinaucinihilipilification', 8]
}


# obfuscate shared secret according to the VoucherEnvelope version
def obfuscate(secret, version):
    """Return *secret* scrambled as required by VoucherEnvelope *version*.

    Version 1 returns *secret* unchanged.  For other versions the secret is
    zero-padded to a multiple of the version's magic number, its bytes are
    permuted, and each byte is XORed with a byte from the first 16 bytes of
    SHA-256(word).  The result is a bytearray.  Versions without an entry in
    VOUCHER_VERSION_INFOS raise KeyError.
    """
    if version == 1:  # v1 does not use obfuscation
        return secret

    word, magic = VOUCHER_VERSION_INFOS[version]

    # extend secret so that its length is divisible by the magic number
    shortfall = -len(secret) % magic
    if shortfall:
        secret = secret + b'\x00' * shortfall

    plain = bytearray(secret)
    total = len(plain)
    stride = total // magic
    scrambled = bytearray(total)
    mask = bytearray(hashlib.sha256(word).digest())

    # shuffle secret and xor it with the first half of the word hash
    for position, byte in enumerate(plain):
        quotient, remainder = divmod(position, stride)
        target = quotient + magic * remainder
        scrambled[target] = byte ^ mask[target % 16]

    return scrambled
2018-03-13 06:35:28 +06:00
class DrmIonVoucher(object):
    """Parses a VoucherEnvelope and decrypts the content key it carries.

    Typical use: construct, parse(), decryptvoucher(), then read
    ``secretkey``.  *dsn* and *secret* are the text values mixed into the
    shared secret for the CLIENT_ID and ACCOUNT_SECRET lock parameters.
    """
    envelope = None
    version = None
    voucher = None
    drmkey = None
    license_type = "Unknown"

    encalgorithm = ""
    enctransformation = ""
    hashalgorithm = ""

    lockparams = None

    ciphertext = b""
    cipheriv = b""
    secretkey = b""

    def __init__(self, voucherenv, dsn, secret):
        """Prepare to read the envelope from the seekable binary stream *voucherenv*."""
        self.dsn, self.secret = dsn, secret

        self.lockparams = []

        self.envelope = BinaryIonParser(voucherenv)
        addprottable(self.envelope)

    def decryptvoucher(self):
        """Derive the voucher key, decrypt the voucher and extract ``secretkey``.

        parse() must have been called first.  Raises Exception on an unknown
        lock parameter, bad padding (wrong key) or unexpected key data.
        """
        shared = "PIDv3" + self.encalgorithm + self.enctransformation + self.hashalgorithm

        self.lockparams.sort()
        for param in self.lockparams:
            if param == "ACCOUNT_SECRET":
                shared += param + self.secret
            elif param == "CLIENT_ID":
                shared += param + self.dsn
            else:
                _assert(False, "Unknown lock parameter: %s" % param)

        sharedsecret = obfuscate(shared.encode('ASCII'), self.version)

        # The HMAC message is given as a bytes literal so that key and
        # message are both byte strings.
        key = hmac.new(sharedsecret, b"PIDv3", digestmod=hashlib.sha256).digest()
        aes = AES.new(key[:32], AES.MODE_CBC, self.cipheriv[:16])
        b = aes.decrypt(self.ciphertext)
        b = pkcs7unpad(b, 16)

        self.drmkey = BinaryIonParser(StringIO(b))
        addprottable(self.drmkey)

        _assert(self.drmkey.hasnext() and self.drmkey.next() == TID_LIST and self.drmkey.gettypename() == "com.amazon.drm.KeySet@1.0",
                "Expected KeySet, got %s" % self.drmkey.gettypename())

        self.drmkey.stepin()
        while self.drmkey.hasnext():
            self.drmkey.next()
            if self.drmkey.gettypename() != "com.amazon.drm.SecretKey@1.0":
                continue

            # Only the first SecretKey entry is used.
            self.drmkey.stepin()
            while self.drmkey.hasnext():
                self.drmkey.next()
                if self.drmkey.getfieldname() == "algorithm":
                    _assert(self.drmkey.stringvalue() == "AES", "Unknown cipher algorithm: %s" % self.drmkey.stringvalue())
                elif self.drmkey.getfieldname() == "format":
                    _assert(self.drmkey.stringvalue() == "RAW", "Unknown key format: %s" % self.drmkey.stringvalue())
                elif self.drmkey.getfieldname() == "encoded":
                    self.secretkey = self.drmkey.lobvalue()

            self.drmkey.stepout()
            break

        self.drmkey.stepout()

    def parse(self):
        """Read the envelope: version, PIDv3 strategy settings and the embedded voucher.

        Raises Exception when the envelope is empty, has an unexpected type
        or strategy, or contains no voucher.
        """
        self.envelope.reset()
        # Start from a clean list so that parsing the same envelope again
        # does not duplicate lock parameters (which would alter the key).
        self.lockparams = []

        _assert(self.envelope.hasnext(), "Envelope is empty")
        _assert(self.envelope.next() == TID_STRUCT and self.envelope.gettypename().startswith("com.amazon.drm.VoucherEnvelope@"),
                "Unknown type encountered in envelope, expected VoucherEnvelope")
        # Type name ends in "@<major>.0"; keep the major number.
        self.version = int(self.envelope.gettypename().split('@')[1][:-2])

        self.envelope.stepin()
        while self.envelope.hasnext():
            self.envelope.next()
            field = self.envelope.getfieldname()
            if field == "voucher":
                self.voucher = BinaryIonParser(StringIO(self.envelope.lobvalue()))
                addprottable(self.voucher)
                continue
            elif field != "strategy":
                continue

            _assert(self.envelope.gettypename() == "com.amazon.drm.PIDv3@1.0", "Unknown strategy: %s" % self.envelope.gettypename())

            self.envelope.stepin()
            while self.envelope.hasnext():
                self.envelope.next()
                field = self.envelope.getfieldname()
                if field == "encryption_algorithm":
                    self.encalgorithm = self.envelope.stringvalue()
                elif field == "encryption_transformation":
                    self.enctransformation = self.envelope.stringvalue()
                elif field == "hashing_algorithm":
                    self.hashalgorithm = self.envelope.stringvalue()
                elif field == "lock_parameters":
                    self.envelope.stepin()
                    while self.envelope.hasnext():
                        _assert(self.envelope.next() == TID_STRING, "Expected string list for lock_parameters")
                        self.lockparams.append(self.envelope.stringvalue())
                    self.envelope.stepout()

            self.envelope.stepout()

        # Fail with a clear message rather than an AttributeError on None.
        _assert(self.voucher is not None, "Envelope does not contain a voucher")
        self.parsevoucher()

    def parsevoucher(self):
        """Read IV, cipher text and license type from the embedded voucher."""
        _assert(self.voucher.hasnext(), "Voucher is empty")
        _assert(self.voucher.next() == TID_STRUCT and self.voucher.gettypename() == "com.amazon.drm.Voucher@1.0",
                "Unknown type, expected Voucher")

        self.voucher.stepin()
        while self.voucher.hasnext():
            self.voucher.next()

            if self.voucher.getfieldname() == "cipher_iv":
                self.cipheriv = self.voucher.lobvalue()
            elif self.voucher.getfieldname() == "cipher_text":
                self.ciphertext = self.voucher.lobvalue()
            elif self.voucher.getfieldname() == "license":
                _assert(self.voucher.gettypename() == "com.amazon.drm.License@1.0",
                        "Unknown license: %s" % self.voucher.gettypename())
                self.voucher.stepin()
                while self.voucher.hasnext():
                    self.voucher.next()
                    if self.voucher.getfieldname() == "license_type":
                        self.license_type = self.voucher.stringvalue()
                self.voucher.stepout()

    def printenvelope(self, lst):
        """Append a text dump of the envelope to *lst*."""
        self.envelope.print_(lst)

    def printkey(self, lst):
        """Append a text dump of the decrypted key set to *lst*, parsing/decrypting first if needed."""
        if self.voucher is None:
            self.parse()
        if self.drmkey is None:
            self.decryptvoucher()

        self.drmkey.print_(lst)

    def printvoucher(self, lst):
        """Append a text dump of the voucher to *lst*, parsing first if needed."""
        if self.voucher is None:
            self.parse()

        self.voucher.print_(lst)

    def getlicensetype(self):
        """Return the license type read from the voucher ("Unknown" if none was seen)."""
        return self.license_type
class DrmIon(object):
    """Decrypts the encrypted pages of a DRMION container.

    *onvoucherrequired* is called with the voucher name found in the
    envelope metadata and must return an object whose ``secretkey``
    attribute holds the page key.
    """
    ion = None
    voucher = None
    vouchername = ""
    key = b""
    onvoucherrequired = None

    def __init__(self, ionstream, onvoucherrequired):
        """Prepare to read the container from the seekable binary stream *ionstream*."""
        self.ion = BinaryIonParser(ionstream)
        addprottable(self.ion)
        self.onvoucherrequired = onvoucherrequired

    def parse(self, outpages):
        """Walk every envelope and write the decrypted page data to *outpages*.

        Raises Exception on unexpected structure, when no key can be
        obtained, or when envelopes name different vouchers.
        """
        self.ion.reset()

        _assert(self.ion.hasnext(), "DRMION envelope is empty")
        _assert(self.ion.next() == TID_SYMBOL and self.ion.gettypename() == "doctype", "Expected doctype symbol")
        _assert(self.ion.next() == TID_LIST and self.ion.gettypename() in ["com.amazon.drm.Envelope@1.0", "com.amazon.drm.Envelope@2.0"],
                "Unknown type encountered in DRMION envelope, expected Envelope, got %s" % self.ion.gettypename())

        while True:
            if self.ion.gettypename() == "enddoc":
                break

            self.ion.stepin()
            while self.ion.hasnext():
                self.ion.next()

                if self.ion.gettypename() in ["com.amazon.drm.EnvelopeMetadata@1.0", "com.amazon.drm.EnvelopeMetadata@2.0"]:
                    self.ion.stepin()
                    while self.ion.hasnext():
                        self.ion.next()
                        if self.ion.getfieldname() != "encryption_voucher":
                            continue

                        if self.vouchername == "":
                            # First voucher reference: fetch the key once.
                            self.vouchername = self.ion.stringvalue()
                            self.voucher = self.onvoucherrequired(self.vouchername)
                            self.key = self.voucher.secretkey
                            _assert(self.key is not None, "Unable to obtain secret key from voucher")
                        else:
                            _assert(self.vouchername == self.ion.stringvalue(),
                                    "Unexpected: Different vouchers required for same file?")
                    self.ion.stepout()

                elif self.ion.gettypename() in ["com.amazon.drm.EncryptedPage@1.0", "com.amazon.drm.EncryptedPage@2.0"]:
                    decompress = False
                    ct = None
                    civ = None
                    self.ion.stepin()
                    while self.ion.hasnext():
                        self.ion.next()
                        if self.ion.gettypename() == "com.amazon.drm.Compressed@1.0":
                            decompress = True
                        if self.ion.getfieldname() == "cipher_text":
                            ct = self.ion.lobvalue()
                        elif self.ion.getfieldname() == "cipher_iv":
                            civ = self.ion.lobvalue()

                    if ct is not None and civ is not None:
                        self.processpage(ct, civ, outpages, decompress)
                    self.ion.stepout()

            self.ion.stepout()
            if not self.ion.hasnext():
                break
            self.ion.next()

    def print_(self, lst):
        """Append a text dump of the container to *lst*."""
        self.ion.print_(lst)

    def processpage(self, ct, civ, outpages, decompress):
        """Decrypt one page (AES-CBC, PKCS#7) and write it to *outpages*.

        When *decompress* is true, the plain text must start with a zero
        byte followed by LZMA data, which is decompressed before writing.
        Raises Exception on bad padding, an unsupported leading byte or a
        truncated LZMA stream.
        """
        aes = AES.new(self.key[:16], AES.MODE_CBC, civ[:16])
        msg = pkcs7unpad(aes.decrypt(ct), 16)

        if not decompress:
            outpages.write(msg)
            return

        # Slice instead of index: also fails cleanly on an empty page.
        _assert(msg[0:1] == b"\x00", "LZMA UseFilter not supported")

        if calibre_lzma is not None:
            with calibre_lzma.decompress(msg[1:], bufsize=0x1000000) as f:
                f.seek(0)
                outpages.write(f.read())
            return

        decomp = lzma.LZMADecompressor(format=lzma.FORMAT_ALONE)
        pending = msg[1:]
        while not decomp.eof:
            segment = decomp.decompress(pending)
            pending = b""  # Contents were internally buffered after the first call
            # All input has been handed over. No output and no end of stream
            # means the data is truncated; stop instead of spinning forever.
            _assert(segment or decomp.eof, "Truncated LZMA stream")
            outpages.write(segment)