#!/usr/bin/env python3
"""Export the float tensors of the ``eccv16`` model to a flat binary file.

Two files are produced:

* the binary given on the command line, holding every float32 parameter and
  buffer back to back, and
* ``code.txt`` in the current directory, holding one line per exported tensor
  of the form ``Tensor <name>{offset, {<shape>}}; offset += <numel>;``
  (presumably meant to be pasted into the code that reads the binary).
"""
import argparse
import array
import functools
import itertools
import struct

import torch

from colorizers import eccv16


def run(args):
    """Write all float32 parameters and buffers of the model to ``args.dst``.

    Args:
        args: Parsed command-line namespace; only ``args.dst`` (path of the
            output binary) is read.

    Side effects:
        Creates/overwrites ``args.dst`` and ``code.txt`` and prints a progress
        line per tensor plus a final total.
    """
    model = eccv16(pretrained=True).eval()

    total_bytesz = 0
    with open('code.txt', 'w', encoding='utf-8') as f_code, \
            open(args.dst, 'wb') as f_bin:
        named_tensors = itertools.chain(model.named_parameters(),
                                        model.named_buffers())
        for name, param in named_tensors:
            # Only float32 tensors are exported; anything else (e.g. integer
            # buffers) is reported and left out of both output files.
            if param.dtype != torch.float:
                print(f'{name} skipped.')
                continue

            # numel() is well defined for every shape, including 0-d tensors,
            # where reducing an empty size with no initial value would raise.
            sz = param.numel()

            # Convert to plain Python floats in one call instead of iterating
            # the tensor element by element. The 'f' format without a prefix
            # uses the machine's native byte order and float size.
            values = param.detach().flatten().tolist()
            f_bin.write(struct.pack(f'{sz}f', *values))

            shape = ', '.join(map(str, param.size()))
            f_code.write(f'Tensor {name.replace(".","_")}{{offset, {{{shape}}}}}; offset += {sz};\n')

            bytesz = sz * param.element_size()
            total_bytesz += bytesz
            print(f'{bytesz} bytes written. (name={name}, shape={{{shape}}})')

    # Reported after both files are closed, so the on-disk size is final.
    print(f'Total {total_bytesz} bytes written. Check binary size to be sure.')


def main():
    """Parse the command line and run the export."""
    parser = argparse.ArgumentParser()
    parser.add_argument('dst', help='Output binary name. (e.g., network.bin)')
    args = parser.parse_args()
    run(args)


if __name__ == '__main__':
    main()