from bits import testBit, clearBit, bytes_to_float, bytes_to_int
from collections import namedtuple
from dataclasses import dataclass
from sys import byteorder
from pathlib import Path
from object_types import types
import argparse
import binascii
import sys
# Feature gates derived from the version of the running interpreter.
PY33 = sys.version_info >= (3, 3)
PY37 = sys.version_info >= (3, 7)
PY38 = sys.version_info >= (3, 8)

# Number of leading bytes MarshalParser skips in files with a ".pyc" suffix.
PYC_HEADER_LEN = 16 if PY37 else (12 if PY33 else 8)

# Bit shift applied per 2-byte digit when a TYPE_LONG value is assembled.
PyLong_MARSHAL_SHIFT = 15

# When true, every recorded object-start line is also printed immediately.
DEBUG = False

# One TYPE_REF occurrence: `byte` is the offset of its type byte,
# `index` is the number of the FLAG_REF slot it points to.
Reference = namedtuple("Reference", "byte index")
# Mutable record describing one parsed object whose type byte had FLAG_REF set.
@dataclass
class Flag_ref:
    # Offset (in the file) of the object's type byte.
    byte: int
    # Name of the marshal type, e.g. "TYPE_CODE".
    type: str
    # The parsed value of the object.
    content: object
    # How many TYPE_REF objects point to this one; bumped by the parser.
    usages: int = 0
class MarshalParser:
    """Parser for marshalled data that tracks FLAG_REF objects and references.

    The whole file is loaded into memory and walked byte by byte.  While
    parsing, every object whose type byte has bit 7 (FLAG_REF) set is
    recorded in ``flag_refs`` and every TYPE_REF object in ``references``.
    This bookkeeping lets ``clear_unused_ref_flags`` drop flags nobody refers
    to and renumber the remaining references.
    """

    def __init__(self, filename):
        """Load *filename* and prepare the byte iterator.

        Args:
            filename: path of the file to parse.  A ``str`` or path-like
                object is accepted; it is stored as ``pathlib.Path``.
        """
        self.filename = Path(filename)
        with open(self.filename, "rb") as fh:
            self.bytes = fh.read()
        # The iterator yields (absolute offset, byte value) pairs; the
        # offsets are reused later to patch the content in place.
        iterator = enumerate(self.bytes)
        # skip pyc header (first n bytes)
        if self.filename.suffix == ".pyc":
            for _ in range(PYC_HEADER_LEN):
                next(iterator)
        self.iterator = iterator

    def parse(self):
        """Reset the parser state and parse one top-level object."""
        self.references = []  # references to existing objects with FLAG_REF
        self.flag_refs = []  # objects with FLAG_REF on
        self.output = ""
        self.indent = 0
        self.read_object()

    def record_object_start(self, i, b, ref_id):
        """
        Records human readable output of parsing process

        Args:
            i: offset of the object's type byte.
            b: value of the type byte with FLAG_REF already cleared.
            ref_id: number of the FLAG_REF slot, or None when the flag
                was not set.
        """
        bytestring = b.to_bytes(1, byteorder)
        hex_byte = binascii.hexlify(bytestring)
        type_name = types[bytestring]
        ref = ""
        if ref_id is not None:
            ref = f"REF[{ref_id}]"
        line = f"n={i}/{hex(i)} byte=({hex_byte}, {bytestring}, " \
               f"{bin(b)}) {type_name} {ref}\n"
        if DEBUG:
            print(line)
        self.output += " " * self.indent + line

    def record_object_result(self, result):
        """
        Records the result of object parsing with its type
        """
        line = f"result={result}, type={type(result)}\n"
        self.output += " " * self.indent + line

    def record_object_info(self, info):
        """
        Records some info about parsed object
        """
        line = f"{info}\n"
        self.output += " " * self.indent + line

    def read_object(self):
        """
        Main method for reading/parsing objects and recording references.
        Simple objects are parsed directly, complex uses other read_* methods

        Returns:
            The parsed value.  TYPE_NULL is returned as the string "null"
            and TYPE_REF as a descriptive string.

        Raises:
            RuntimeError: the type byte is known but no branch handles it.
            StopIteration: the input ends before the object is complete.

        An unknown type byte prints a diagnostic and exits the process.
        """
        i, b = next(self.iterator)
        ref_id = None
        if testBit(b, 7):
            b = clearBit(b, 7)
            # Reserve the slot now so that nested objects get later slot
            # numbers; it is filled in once this object is fully parsed.
            ref_id = len(self.flag_refs)
            self.flag_refs.append(None)
        bytestring = b.to_bytes(1, byteorder)
        try:
            type_name = types[bytestring]
        except KeyError:
            print(f"Cannot read/parse byte {b} {bytestring} on possition {i}")
            print("Might be error or unsupported TYPE")
            print(self.output)
            sys.exit(1)

        self.record_object_start(i, b, ref_id)

        # Increase indentation
        self.indent += 2

        if type_name == "TYPE_CODE":
            result = self.read_codeobject()
        elif type_name == "TYPE_LONG":
            result = self.read_py_long()
        elif type_name == "TYPE_INT":
            # TYPE_INT is a signed 32-bit value in the marshal format
            result = self.read_long(signed=True)
        elif type_name in ("TYPE_STRING", "TYPE_UNICODE",
                           "TYPE_ASCII", "TYPE_INTERNED"):
            result = self.read_string()
        elif type_name == "TYPE_SMALL_TUPLE":
            # small tuple — size is only one byte
            size = bytes_to_int(self.read_bytes())
            self.record_object_info(f"Small tuple size: {size}")
            result = tuple([self.read_object() for _ in range(size)])
        elif type_name in ("TYPE_TUPLE", "TYPE_LIST",
                           "TYPE_SET", "TYPE_FROZENSET"):
            # regular tuple, list, set, frozenset
            size = self.read_long()
            self.record_object_info(f"tuple/list/set size: {size}")
            result = [self.read_object() for _ in range(size)]
            if type_name == "TYPE_TUPLE":
                result = tuple(result)
            elif type_name == "TYPE_SET":
                result = set(result)
            elif type_name == "TYPE_FROZENSET":
                result = frozenset(result)
        elif type_name == "TYPE_NULL":
            result = "null"
        elif type_name == "TYPE_NONE":
            result = None
        elif type_name == "TYPE_TRUE":
            result = True
        elif type_name == "TYPE_FALSE":
            result = False
        elif type_name == "TYPE_STOPITER":
            result = StopIteration
        elif type_name == "TYPE_ELLIPSIS":
            result = ...
        elif type_name in ("TYPE_SHORT_ASCII_INTERNED", "TYPE_SHORT_ASCII"):
            result = self.read_string(short=True)
        elif type_name == "TYPE_REF":
            index = self.read_long()
            self.references.append(Reference(byte=i, index=index))
            # NOTE: a slot stays None until its object is fully parsed, so a
            # reference to a still-unfinished object fails on the next line.
            self.flag_refs[index].usages += 1
            result = f"REF to {index}: " + str(self.flag_refs[index])
        elif type_name == "TYPE_BINARY_FLOAT":
            result = bytes_to_float(self.read_bytes(count=8))
        elif type_name == "TYPE_BINARY_COMPLEX":
            real = bytes_to_float(self.read_bytes(count=8))
            imag = bytes_to_float(self.read_bytes(count=8))
            result = complex(real, imag)
        elif type_name == "TYPE_DICT":
            result = {}
            # key/value pairs follow until a TYPE_NULL shows up as the key
            while True:
                key = self.read_object()
                if key == "null":
                    break
                value = self.read_object()
                result[key] = value
        else:
            self.indent -= 2
            raise RuntimeError(
                f"Error: type [{type_name}] is recognized "
                "but result is not present"
            )

        # decrease indentation
        self.indent -= 2

        self.record_object_result(result)

        # Fill the slot reserved at the beginning of this call
        if ref_id is not None:
            self.flag_refs[ref_id] = Flag_ref(
                byte=i, type=type_name, content=result
            )

        return result

    def read_bytes(self, count=1):
        """Consume *count* bytes from the iterator and return them as bytes.

        Raises:
            StopIteration: fewer than *count* bytes are left.
        """
        chunk = bytearray()
        for _ in range(count):
            _, value = next(self.iterator)
            chunk.append(value)
        return bytes(chunk)

    def read_string(self, size=None, short=False):
        """Read a length-prefixed byte string and return the raw bytes.

        Args:
            size: number of bytes to read; when None the size is read
                from the input first.
            short: when the size is read from the input, True means it is
                stored as one byte, False as a long (4 bytes).
        """
        if size is None:
            if short:
                # short == size is stored as one byte
                size = bytes_to_int(self.read_bytes())
            else:
                # non-short == size is stored as long (4 bytes)
                size = self.read_long()
        return self.read_bytes(size)

    def read_long(self, signed=False):
        """Read 4 bytes and convert them with ``bytes_to_int``."""
        return bytes_to_int(self.read_bytes(count=4), signed=signed)

    def read_short(self):
        """Read 2 bytes as a little-endian signed 16-bit integer."""
        b = self.read_bytes(count=2)
        x = b[0] | (b[1] << 8)
        # Sign-extension: values with bit 15 set become negative
        x |= -(x & 0x8000)
        return x

    def read_py_long(self):
        """Read an arbitrary-size integer stored as a signed digit count
        followed by that many 2-byte digits, least significant first.
        """
        n = self.read_long(signed=True)
        result = 0
        for position in range(abs(n)):
            result += self.read_short() << (position * PyLong_MARSHAL_SHIFT)
        return result if n > 0 else -result

    def read_codeobject(self):
        """Read the fields of a code object and return them as a dict.

        The field layout follows the running interpreter:
        ``posonlyargcount`` is present from 3.8 on, and from 3.11 on the
        layout is argcount, posonlyargcount, kwonlyargcount, stacksize,
        flags, code, consts, names, localsplusnames, localspluskinds,
        filename, name, qualname, firstlineno, linetable, exceptiontable.

        Returns:
            dict mapping field name to parsed value, in file order.
        """
        modern = sys.version_info >= (3, 11)
        co = {"argcount": self.read_long()}
        if PY38:
            co["posonlyargcount"] = self.read_long()
        co["kwonlyargcount"] = self.read_long()
        if not modern:
            co["nlocals"] = self.read_long()
        co["stacksize"] = self.read_long()
        co["flags"] = self.read_long()

        if modern:
            leading_objects = ("code", "consts", "names", "localsplusnames",
                               "localspluskinds", "filename", "name",
                               "qualname")
            trailing_objects = ("linetable", "exceptiontable")
        else:
            leading_objects = ("code", "consts", "names", "varnames",
                               "freevars", "cellvars", "filename", "name")
            trailing_objects = ("lnotab",)

        for field_name in leading_objects:
            co[field_name] = self.read_object()
        co["firstlineno"] = self.read_long()
        for field_name in trailing_objects:
            co[field_name] = self.read_object()
        return co

    def unused_ref_flags(self):
        """Return ``(slot number, flag_ref)`` pairs that were never used."""
        return [
            (index, flag_ref)
            for index, flag_ref in enumerate(self.flag_refs)
            if flag_ref.usages == 0
        ]

    def clear_unused_ref_flags(self, overwrite=False):
        """Clear FLAG_REF on unreferenced objects and renumber references.

        The result is written only when it differs from the loaded content.

        Args:
            overwrite: write back to the original file name when True,
                otherwise to a sibling file with ".fixed" inserted before
                the original suffix.
        """
        # new mutable content
        content = bytearray(self.bytes)

        # Map original slot numbers to the numbers after unused slots are
        # removed.  Slots are numbered in the order their objects start, so
        # the new number is the count of used slots that come before.
        new_indexes = {}
        for old_index, flag_ref in enumerate(self.flag_refs):
            if flag_ref.usages == 0:
                # Clear FLAG_REF bit; the slot gets no new number
                content[flag_ref.byte] = clearBit(content[flag_ref.byte], 7)
            else:
                new_indexes[old_index] = len(new_indexes)

        for reference in self.references:
            new_index = new_indexes[reference.index]
            # write new number as 4-byte integer; marshal data is
            # little-endian regardless of the host byte order
            start = reference.byte + 1
            content[start:start + 4] = new_index.to_bytes(4, "little")

        # Skip writing if there is no difference
        if bytes(content) == self.bytes:
            print("Content is the same, nothing to fix…")
            return

        suffix = "" if overwrite else ".fixed"
        new_name = self.filename.with_suffix(suffix + self.filename.suffix)
        with open(new_name, mode="wb") as fh:
            fh.write(content)
def main():
    """Command-line entry point.

    Parses every file given on the command line and, depending on the
    selected switches, prints the parser output, lists unused FLAG_REFs
    and/or writes a fixed copy of the file.
    """
    cli = argparse.ArgumentParser(
        description="Marshalparser and fixer for .pyc files"
    )
    # (short option, long option, help text) of the boolean switches
    switches = (
        ("-p", "--print", "Print human-readable parser output"),
        ("-u", "--unused", "Print unused references"),
        ("-f", "--fix", "Fix references"),
        ("-o", "--overwrite",
         "Overwrite existing pyc file (works with --fix)"),
    )
    for short_flag, long_flag, help_text in switches:
        cli.add_argument(
            short_flag,
            long_flag,
            action="store_true",
            dest=long_flag[2:],
            default=False,
            help=help_text,
        )
    cli.add_argument(metavar="files", dest="files", nargs="*")

    options = cli.parse_args()

    for path in options.files:
        marshal_parser = MarshalParser(Path(path))
        marshal_parser.parse()

        if options.print:
            print(marshal_parser.output)

        if options.unused:
            leftovers = marshal_parser.unused_ref_flags()
            if leftovers:
                print("Unused FLAG_REFs:")
                print("\n".join(f"{i} - {f}" for i, f in leftovers))

        if options.fix:
            marshal_parser.clear_unused_ref_flags(
                overwrite=options.overwrite
            )


if __name__ == "__main__":
    main()