chundoong-lab-ta/SamsungDS22/project/project_skeleton/misc/net2bin

38 lines
1.3 KiB
Plaintext
Raw Normal View History

2022-09-29 18:01:45 +09:00
#!/usr/bin/env python3
import argparse
import itertools, functools
import array
import struct
import torch
from colorizers import eccv16
def run(args):
    """Dump the float32 tensors of the pretrained ``eccv16`` model to a binary.

    Every parameter and buffer whose dtype is ``torch.float`` is appended to
    ``args.dst`` as a flat run of C floats in the machine's native
    representation, iterating ``named_parameters()`` first and
    ``named_buffers()`` second.  For each tensor written, one line of the form
    ``Tensor <name>{offset, {<shape>}}; offset += <numel>;`` is written to
    ``code.txt`` in the current working directory, with dots in the tensor
    name replaced by underscores.  Tensors of any other dtype are skipped and
    reported on stdout.

    Args:
        args: Parsed command-line namespace; only ``args.dst`` (path of the
            output binary) is read.
    """
    model = eccv16(pretrained=True).eval()
    total_bytesz = 0
    with open('code.txt', 'w') as f_code, open(args.dst, 'wb') as f_bin:
        tensors = itertools.chain(model.named_parameters(), model.named_buffers())
        for name, param in tensors:
            if param.dtype != torch.float:  # only float32 tensors are exported
                print(f'{name} skipped.')
                continue

            # Convert once to a packed C-float buffer instead of unpacking
            # every element as a separate argument to struct.pack(); the
            # bytes are the same native-order floats, produced far faster.
            payload = array.array('f', param.detach().flatten().tolist())
            payload.tofile(f_bin)

            shape = ', '.join(map(str, param.size()))
            # numel() is also correct for 0-d tensors, where reducing over an
            # empty shape without an initial value would raise TypeError.
            sz = param.numel()
            f_code.write(f'Tensor {name.replace(".","_")}{{offset, {{{shape}}}}}; offset += {sz};\n')

            # Report what was actually serialized, not a derived estimate.
            bytesz = len(payload) * payload.itemsize
            total_bytesz += bytesz
            print(f'{bytesz} bytes written. (name={name}, shape={{{shape}}})')
    print(f'Total {total_bytesz} bytes written. Check binary size to be sure.')
def main():
    """Parse the command line and pass the result to ``run``.

    The command line takes a single positional argument, ``dst``, naming the
    output binary.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('dst', help='Output binary name. (e.g., network.bin)')
    run(cli.parse_args())
# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()