38 lines
1.3 KiB
Plaintext
38 lines
1.3 KiB
Plaintext
|
#!/usr/bin/env python3
|
||
|
|
||
|
import argparse
|
||
|
import itertools, functools
|
||
|
import array
|
||
|
import struct
|
||
|
import torch
|
||
|
from colorizers import eccv16
|
||
|
|
||
|
def run(args):
    """Dump the float tensors of ``eccv16(pretrained=True)`` to disk.

    Two files are produced:

    * ``args.dst`` -- raw binary. Every ``torch.float`` parameter and buffer
      of the model, flattened and concatenated in iteration order
      (``named_parameters()`` first, then ``named_buffers()``), stored as C
      floats in the machine's native byte order.
    * ``code.txt`` (in the current working directory) -- one line per written
      tensor of the form
      ``Tensor <name>{offset, {<shape>}}; offset += <element count>;`` where
      dots in the tensor name are replaced by underscores.

    Tensors whose dtype is not ``torch.float`` are reported on stdout and
    skipped. A per-tensor and a total byte count are printed as well.

    Args:
        args: Parsed command-line namespace; only ``args.dst`` (output binary
            path) is read.
    """
    model = eccv16(pretrained=True).eval()

    total_bytesz = 0
    with open('code.txt', 'w') as f_code, \
            open(args.dst, 'wb') as f_bin:
        named_tensors = itertools.chain(model.named_parameters(),
                                        model.named_buffers())
        for name, param in named_tensors:
            if param.dtype != torch.float:  # only float tensors are exported
                print(f'{name} skipped.')
                continue

            # numel() also handles 0-dim tensors (count 1), where reducing
            # an empty size() without an initial value would raise.
            sz = param.numel()

            # tolist() converts the whole tensor in one call instead of
            # creating one tensor object per element; array('f') then stores
            # the values as native C floats, the same layout struct's 'f'
            # format produces. detach() keeps autograd out of the export.
            values = array.array('f', param.detach().flatten().tolist())
            values.tofile(f_bin)

            shape = ', '.join(map(str, param.size()))
            f_code.write(f'Tensor {name.replace(".","_")}{{offset, {{{shape}}}}}; offset += {sz};\n')

            elem_sz = param.element_size()
            bytesz = sz * elem_sz
            total_bytesz += bytesz
            print(f'{bytesz} bytes written. (name={name}, shape={{{shape}}})')

    print(f'Total {total_bytesz} bytes written. Check binary size to be sure.')
|
||
|
|
||
|
def main():
    """Command-line entry point.

    Parses the single positional argument ``dst`` (path of the output
    binary) and hands the resulting namespace to ``run``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('dst', help='Output binary name. (e.g., network.bin)')
    run(cli.parse_args())
|
||
|
|
||
|
# Run the exporter only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|