#!/usr/bin/env python3
"""Code for iterating over and filtering PNG file chunks.

See https://en.wikipedia.org/wiki/Portable_Network_Graphics for a description
of the PNG file format.
"""

import sys
import argparse
import struct


class InvalidPngError(Exception):
    """Signals that the input could not be parsed as PNG data.
    """

# The 8-byte signature expected at the very start of a PNG stream.
PNG_HEADER = b"\x89PNG\r\n\x1a\n"

class PngIterator:
    """Iterates over the chunks in a PNG file/stream.

    Each item produced is a 4-tuple of raw ``bytes`` objects,
    ``(raw_length, chunk_type, chunk_data, chunk_crc)``, exactly as they
    appear in the stream so a caller can write them back out unchanged.
    The CRC is returned as-is and is not verified.

    Raises:
        InvalidPngError: from the constructor if the stream does not begin
            with PNG_HEADER, and from iteration if a chunk is cut short.
    """

    def __init__(self, stream):
        """Wrap *stream* and consume and validate the PNG signature.

        Args:
            stream: binary file-like object with a ``read(size)`` method,
                positioned at the start of the PNG data.
        """
        self.stream = stream
        header = self._read_upto(len(PNG_HEADER))
        if PNG_HEADER != header:
            raise InvalidPngError(f"Not a PNG; found {header}")

    def __iter__(self):
        return self

    def _read_upto(self, size):
        """Read up to *size* bytes, retrying short reads until EOF.

        Raw/unbuffered streams may legally return fewer bytes than asked
        for, so keep reading until *size* bytes arrive or EOF is hit.
        """
        parts = []
        remaining = size
        while remaining > 0:
            piece = self.stream.read(remaining)
            if not piece:  # EOF (or no data available)
                break
            parts.append(piece)
            remaining -= len(piece)
        return b"".join(parts)

    def _read_exact(self, size, what):
        """Read exactly *size* bytes or raise InvalidPngError naming *what*."""
        data = self._read_upto(size)
        if len(data) != size:
            raise InvalidPngError(
                f"Truncated {what}: expected {size} bytes, got {len(data)}")
        return data

    def __next__(self):
        raw_length = self._read_upto(4)
        if not raw_length:  # Clean EOF between chunks: nothing more to read
            raise StopIteration
        if len(raw_length) != 4:
            # A partial length field would otherwise surface as struct.error.
            raise InvalidPngError(
                f"Truncated chunk length: expected 4 bytes, "
                f"got {len(raw_length)}")
        (length,) = struct.unpack('>L', raw_length)
        chunk_type = self._read_exact(4, "chunk type")
        chunk_data = self._read_exact(length, "chunk data")
        chunk_crc = self._read_exact(4, "chunk CRC")
        return (raw_length, chunk_type, chunk_data, chunk_crc)


def main(argv=None):
    """Filter chunks from a PNG file.
    """
    # NOTE: the docstring above is also used verbatim as the argparse
    # description, so it is deliberately kept to one line.
    #
    # Args:
    #     argv: full argument vector including the program name; defaults
    #         to sys.argv.  "-" for filename/target means stdin/stdout.
    # Returns:
    #     Process exit status: 0 on success, 1 if the input turns out to be
    #     malformed part-way through (the target may then be partial).
    #     Usage and open/header errors exit via parser.error().
    import contextlib  # local import keeps this function self-contained

    if argv is None:
        argv = sys.argv

    parser = argparse.ArgumentParser(description=main.__doc__)
    parser.add_argument('--exclude', action='append', default=[],
                        help="chunk types to remove from the PNG image.")
    parser.add_argument('--verbose', action='store_true',
                        help="list chunks encountered and exclusions")
    parser.add_argument('filename')
    parser.add_argument('target')
    args = parser.parse_args(argv[1:])

    if args.verbose:
        verbose = sys.stderr.write
    else:
        def verbose(_message):
            pass

    # ExitStack closes only the files opened here (never stdin/stdout),
    # including when parser.error() exits early via SystemExit.
    with contextlib.ExitStack() as stack:
        if args.filename == '-':
            source_file = sys.stdin.buffer # binary data
        else:
            try:
                source_file = stack.enter_context(open(args.filename, 'rb'))
            except OSError as error:
                parser.error(f"Cannot open input: {error}")

        # Ensure the input file is a valid PNG before we create the target file.
        try:
            png_chunks = PngIterator(source_file)
        except InvalidPngError as error:
            parser.error(f"Bad input: {error}")

        if args.target == '-':
            target_file = sys.stdout.buffer # binary data
        else:
            try:
                target_file = stack.enter_context(open(args.target, 'wb'))
            except OSError as error:
                parser.error(f"Cannot open output: {error}")

        if args.exclude:
            verbose(f"Excluding {', '.join(sorted(args.exclude))} chunks\n")
        excludes = {bytes(x, 'utf8') for x in args.exclude}

        target_file.write(PNG_HEADER)
        try:
            for raw_length, chunk_type, chunk_data, chunk_crc in png_chunks:
                # Malformed chunk types need not be valid UTF-8; don't crash
                # just to print a diagnostic.
                type_name = chunk_type.decode('ascii', errors='replace')
                verbose(f"Found {type_name} chunk\n")
                if chunk_type in excludes:
                    verbose(f"Excluding {type_name} chunk\n")
                else:
                    target_file.write(
                        raw_length + chunk_type + chunk_data + chunk_crc)
        except InvalidPngError as error:
            sys.stderr.write(f"{parser.prog}: error: Bad input: {error}\n")
            return 1
        target_file.flush()

    return 0


if __name__ == '__main__':
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
