41 lines
1.2 KiB
Plaintext
41 lines
1.2 KiB
Plaintext
|
#!/usr/bin/env python3
|
||
|
|
||
|
import argparse
|
||
|
import itertools, functools
|
||
|
import array
|
||
|
import os
|
||
|
import struct
|
||
|
import torch
|
||
|
from colorizers import eccv16
|
||
|
import numpy as np
|
||
|
from PIL import Image
|
||
|
from skimage import color
|
||
|
|
||
|
def _load_lightness(img_path):
    """Return the L (lightness) channel of ``img_path`` as a 256x256 array.

    The image is converted to 8-bit RGB first, so grayscale, RGBA, palette
    and other Pillow modes all go through the same path; it is then resized
    to 256x256 with bicubic resampling and converted to CIELAB.
    """
    # The context manager closes the underlying file handle, which matters
    # when many images are processed in one invocation.
    with Image.open(img_path) as src_img:
        rgb = src_img.convert('RGB').resize((256, 256), resample=Image.BICUBIC)
    return color.rgb2lab(np.asarray(rgb))[:, :, 0]


def run(args):
    """Run the pretrained ECCV16 colorizer on each image and dump the outputs.

    Args:
        args: Namespace with ``src`` (iterable of input image paths) and
            ``dst`` (path of the binary file to create or overwrite).

    For every input, the model output for the single batch element is
    flattened and appended to ``args.dst`` as raw float32 values in native
    byte order, with no header or separator between images.
    """
    model = eccv16(pretrained=True).eval()

    with open(args.dst, 'wb') as f:
        for img_path in args.src:
            print(f'Processing {img_path}...')
            lightness = _load_lightness(img_path)

            # Add batch and channel axes: (256, 256) -> (1, 1, 256, 256).
            tin = torch.as_tensor(lightness, dtype=torch.float32)[None, None, :, :]

            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                ref = model(tin)

            ref = ref[0].numpy().flatten()
            # Same bytes as struct.pack(f'{ref.size}f', *ref), without
            # unpacking every element into a Python argument.
            f.write(ref.astype(np.float32, copy=False).tobytes())
|
||
|
|
||
|
def main():
    """Parse the command line and pass the resulting namespace to ``run``.

    Positional arguments are one or more input image paths (``src``)
    followed by the output binary path (``dst``).
    """
    cli = argparse.ArgumentParser()
    positionals = (
        ('src', {'nargs': '+', 'help': 'Input images. (e.g., imgs/*)'}),
        ('dst', {'help': 'Output binary name. (e.g., output.bin)'}),
    )
    for arg_name, options in positionals:
        cli.add_argument(arg_name, **options)
    run(cli.parse_args())
|
||
|
|
||
|
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|