38 lines
1.3 KiB
Python
38 lines
1.3 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import argparse
|
|
import itertools, functools
|
|
import array
|
|
import struct
|
|
import torch
|
|
from colorizers import eccv16
|
|
|
|
def run(args):
    """Dump the float32 tensors of the pretrained ``eccv16`` model to disk.

    The model is built with ``eccv16(pretrained=True)`` and put in eval mode.
    Its tensors are visited in the order ``named_parameters()`` followed by
    ``named_buffers()``.  For each tensor whose dtype is ``torch.float``:

    * its elements are appended to ``args.dst`` as raw 32-bit floats in
      native byte order, flattened in the tensor's own element order;
    * one line of the form
      ``Tensor <name>{offset, {<shape>}}; offset += <numel>;`` is appended
      to ``code.txt`` in the current working directory (dots in the tensor
      name are replaced by underscores).

    Tensors of any other dtype are reported on stdout and skipped.  A
    per-tensor byte count and a final total are printed as well.

    Args:
        args: Parsed command-line namespace.  Only ``args.dst`` (path of the
            binary output file) is read.
    """
    model = eccv16(pretrained=True).eval()

    with open('code.txt', 'w') as f_code, open(args.dst, 'wb') as f_bin:
        total_bytesz = 0
        tensors = itertools.chain(model.named_parameters(), model.named_buffers())
        for name, param in tensors:
            if param.dtype != torch.float:  # only float32 tensors are exported
                print(f'{name} skipped.')
                continue

            # numel() is 1 for a 0-dim tensor, whereas reducing an empty
            # size tuple without an initial value raises TypeError.
            sz = param.numel()

            # Convert to Python floats in one call instead of iterating the
            # tensor element by element (which builds a 0-d tensor for every
            # value), then write them as native-endian C floats -- the same
            # bytes struct.pack(f'{sz}f', ...) would produce.
            values = param.detach().cpu().flatten().tolist()
            array.array('f', values).tofile(f_bin)

            shape = ', '.join(map(str, param.size()))
            f_code.write(f'Tensor {name.replace(".","_")}{{offset, {{{shape}}}}}; offset += {sz};\n')

            elem_sz = param.element_size()
            bytesz = sz * elem_sz
            total_bytesz += bytesz
            print(f'{bytesz} bytes written. (name={name}, shape={{{shape}}})')
        print(f'Total {total_bytesz} bytes written. Check binary size to be sure.')
|
|
|
|
def main():
    """Parse the command line and hand the resulting namespace to ``run``.

    The only argument is the positional ``dst``: the path of the binary
    file to write.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        'dst',
        help='Output binary name. (e.g., network.bin)',
    )
    run(cli.parse_args())
|
|
|
|
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
|