2023-02-04 17:43:29 +09:00
|
|
|
from PIL import Image
|
|
|
|
import sys
|
|
|
|
import numpy as np
|
2023-02-06 15:07:20 +09:00
|
|
|
from functools import reduce
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import torch
|
2023-02-04 17:43:29 +09:00
|
|
|
|
|
|
|
|
|
|
|
# Shape each raw prediction is reshaped to before upsampling. The script below
# feeds it to F.interpolate (last two axes are spatial) and takes argmax over
# axis 1, i.e. it is used as (batch, classes, height, width).
imgSize = (1, 2, 640, 959)
|
|
|
|
|
|
|
|
def mask_to_image(mask: np.ndarray, mask_values):
    """Render a class-index mask as a PIL image.

    Args:
        mask: Either a 2-D array of class indices, or a 3-D array that is
            reduced to class indices with ``argmax`` over axis 0.
        mask_values: Sequence whose i-th entry is the pixel value written
            wherever the mask equals ``i``. If the first entry is a list, the
            output gets a trailing channel axis of that length.

    Returns:
        The image produced by ``Image.fromarray`` from the filled array.
    """
    height, width = mask.shape[-2], mask.shape[-1]
    first_value = mask_values[0]

    if isinstance(first_value, list):
        # Multi-channel pixel values: one uint8 channel per list entry.
        canvas = np.zeros((height, width, len(first_value)), dtype=np.uint8)
    else:
        # The plain [0, 1] value set gets a boolean canvas; anything else uint8.
        pixel_type = bool if mask_values == [0, 1] else np.uint8
        canvas = np.zeros((height, width), dtype=pixel_type)

    # Collapse a 3-D score/one-hot stack into a 2-D map of class indices.
    labels = np.argmax(mask, axis=0) if mask.ndim == 3 else mask

    for class_index, pixel_value in enumerate(mask_values):
        canvas[labels == class_index] = pixel_value

    return Image.fromarray(canvas)
|
|
|
|
|
|
|
|
def indexed_output_path(path, idx):
    """Return ``path`` with its extension replaced by ``_<idx>.png``.

    Only the final extension is stripped, so dots in directory names or a
    leading ``./`` are preserved (``'./out/res.png'`` -> ``'./out/res_0.png'``).
    Splitting on the first ``'.'`` would truncate such paths.

    Args:
        path: Output path given on the command line, with or without extension.
        idx: Index of the image being written.

    Returns:
        The numbered output path as a string, always ending in ``.png``.
    """
    import os  # local import: keeps the file's top-level import block untouched

    stem, _ = os.path.splitext(path)
    return stem + '_' + str(idx) + '.png'


if __name__=='__main__':
    # Command line: <raw float32 file> <output image path> <number of images>
    if len(sys.argv) < 4:
        sys.exit('usage: ' + sys.argv[0] + ' <raw_float32_file> <output_path> <image_count>')

    # Number of float32 values that make up one prediction of shape imgSize.
    sizeMult = reduce(lambda x, y: x*y, imgSize)
    mask_values = [0,1]
    # Predictions are upsampled to twice the stored height and width.
    upsampled_size = (imgSize[2]*2, imgSize[3]*2)

    # The input file is read as a flat float32 buffer and sliced per image.
    imgAll = np.fromfile(sys.argv[1], dtype=np.float32)
    imgAll = torch.from_numpy(imgAll)

    for idx in range(int(sys.argv[3])):
        img = imgAll[(sizeMult * idx) : (sizeMult * (idx+1))].reshape(imgSize)
        img = F.interpolate(img, upsampled_size, mode='bilinear')
        # Per-pixel winning index along axis 1, then drop the batch axis.
        mask = img.argmax(dim=1)
        mask = mask[0].long().squeeze().numpy()
        result = mask_to_image(mask, mask_values)
        result.save(indexed_output_path(sys.argv[2], idx))
|