implement up1

sota-junsik 2023-02-05 13:44:57 +00:00
parent a594dc0ec5
commit b2673ed518
3 changed files with 66 additions and 16 deletions

Binary file not shown.


@@ -48,7 +48,7 @@ class Down(nn.Module):
#print("param:", list(self.maxpool_conv[1].double_conv[4].state_dict()['bias']))
x1 = self.maxpool_conv(x)
#print("input", x)
print("last",x1)
#print("last",x1)
return self.maxpool_conv(x)
@@ -69,16 +69,20 @@ class Up(nn.Module):
    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        # x1,2 = x5, x4
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        #print(x1.size(), "x1 pad", x1)
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        #print(x1.size(), "x1 pad", x1)
        # print("diffY, X", diffY, diffX, diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2)
        x = torch.cat([x2, x1], dim=1)
        #print("bn running mean, var", self.conv.double_conv[4].running_var)
        #print(x.size(), "cat",x)
        x3 = self.conv(x)
        print("x3", x3)
        return self.conv(x)
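Why diffX ends up as 1 here: for the 640x959 input implied by the tensor shapes in the C++ port below, a ConvTranspose2d with stride 2 and kernel 2 exactly doubles each spatial dimension, so the odd skip-connection widths (959, 479, 239, 119) come up one pixel short while the even heights already match. A standalone C++ sketch of that size formula, not part of either file (convt_out is a hypothetical helper):

#include <cstdio>

// PyTorch ConvTranspose2d output size with no output padding or dilation:
// out = (in - 1) * stride - 2 * padding + kernel
static int convt_out(int in, int stride, int padding, int kernel) {
    return (in - 1) * stride - 2 * padding + kernel;
}

int main() {
    // up1: down4 output is 40 x 59, the down3 skip connection is 80 x 119.
    int h = convt_out(40, 2, 0, 2);   // 80  -> diffY = 0
    int w = convt_out(59, 2, 0, 2);   // 118 -> diffX = 1, padded on the right
    printf("upsampled %d x %d, skip 80 x 119\n", h, w);
    return 0;
}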


@@ -135,13 +135,18 @@ Tensor *down4_conv_0_output = new Tensor({1, 1024, 40, 59});
Tensor *down4_batchnorm_0_output = new Tensor({1, 1024, 40, 59});
Tensor *down4_conv_1_output = new Tensor({1, 1024, 40, 59});
Tensor *down4_batchnorm_1_output = new Tensor({1, 1024, 40, 59});
Tensor *up1_convt_0_output = new Tensor({1, 512, 80, 119});
Tensor *up1_concat_0_output = new Tensor({1, 512, 80, 119});
Tensor *up2_convt_0_output = new Tensor({1, 256, 160, 239});
Tensor *up1_convt_0_output = new Tensor({1, 512, 80, 118});
Tensor *up1_concat_0_output = new Tensor({1, 1024, 80, 119});
Tensor *up1_conv_0_output = new Tensor({1, 512, 80, 119});
Tensor *up1_batchnorm_0_output = new Tensor({1, 512, 80, 119});
Tensor *up1_conv_1_output = new Tensor({1, 512, 80, 119});
Tensor *up1_batchnorm_1_output = new Tensor({1, 512, 80, 119});
Tensor *up2_convt_0_output = new Tensor({1, 256, 160, 238});
Tensor *up2_concat_0_output = new Tensor({1, 256, 160, 239});
Tensor *up3_convt_0_output = new Tensor({1, 128, 320, 479});
Tensor *up3_convt_0_output = new Tensor({1, 128, 320, 478});
Tensor *up3_concat_0_output = new Tensor({1, 128, 320, 479});
Tensor *up4_convt_0_output = new Tensor({1, 64, 640, 959});
Tensor *up4_convt_0_output = new Tensor({1, 64, 640, 958});
Tensor *up4_concat_0_output = new Tensor({1, 64, 640, 959});
Tensor *outc_conv_0_output = new Tensor({1, 2, 640, 959});
@@ -208,11 +213,19 @@ void uNet(Tensor *input, Tensor *output) {
ReLU(down4_batchnorm_1_output);
/*
// up1(1024, 512), (down4_batchnorm_1_output, down3_batchnorm_1_output)
ConvTranspose2d(down4_batchnorm_1_output, up1_up_weight, up1_up_bias, up1_convt_0_output, 2, 0);
Concat(up1_convt_0_output, down3_batchnorm_1_output, up1_concat_0_output);
Conv2d(up1_concat_0_output, up1_conv_double_conv_0_weight, NULL, up1_conv_0_output, 1, 1, 1, false);
BatchNorm2d(up1_conv_0_output, up1_conv_double_conv_1_weight, up1_conv_double_conv_1_bias, up1_batchnorm_0_running_mean, up1_batchnorm_0_running_var, up1_batchnorm_0_output, 1e-5, 0.1);
ReLU(up1_batchnorm_0_output);
Conv2d(up1_batchnorm_0_output, up1_conv_double_conv_3_weight, NULL, up1_conv_1_output, 1, 1, 1, false);
BatchNorm2d(up1_conv_1_output, up1_conv_double_conv_4_weight, up1_conv_double_conv_4_bias, up1_batchnorm_1_running_mean, up1_batchnorm_1_running_var, up1_batchnorm_1_output, 1e-5, 0.1);
ReLU(up1_batchnorm_1_output);
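// Note: the sequence above mirrors Up.forward in the PyTorch model --
// ConvTranspose2d is self.up, Concat reproduces torch.cat([x2, x1], dim=1)
// together with the one-column zero pad from F.pad, and the two
// Conv2d/BatchNorm2d/ReLU triples are the DoubleConv block (1024 -> 512).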
/*
// up2(512, 256), (up1_concat_0_output, down2_batchnorm_1_output)
ConvTranspose2d(up1_concat_0_output, up2_up_weight, up2_up_bias, up2_convt_0_output, 2, 0);
Concat(up2_convt_0_output, down2_batchnorm_1_output, up2_concat_0_output);
@@ -224,7 +237,8 @@ void uNet(Tensor *input, Tensor *output) {
// up4(128, 64), (up3_concat_0_output, inc_batchnorm_1_output)
ConvTranspose2d(up3_concat_0_output, up4_up_weight, up4_up_bias, up4_convt_0_output, 2, 0);
Concat(up4_convt_0_output, down2_batchnorm_1_output, up4_concat_0_output);
*/
/*
// outc(64, n_classes)
Conv2d(up4_concat_0_output, outc_conv_weight, outc_conv_bias, output_tensor, 1, 0, 0, true);
output = output_tensor->buf;
@@ -411,8 +425,40 @@ void MaxPool2d(Tensor *input, Tensor *output){
}
}
void Concat(Tensor *input1, Tensor *input2, Tensor *output){
  // TODO
  int C = input1->shape[1], H = input1->shape[2], W = input1->shape[3];
  int C2 = input2->shape[1], H2 = input2->shape[2], W2 = input2->shape[3];
  int OC = output->shape[1], OH = output->shape[2], OW = output->shape[3];
  //printf("C %d, H %d, W %d, C2 %d, H2 %d, W2 %d, OC %d, OH %d, OW %d\n", C, H, W, C2, H2, W2, OC, OH, OW);
  // First half of the output channels: copy input2 (the skip connection,
  // which already has the full output height and width) straight through.
  for (int oc=0; oc<OC/2; ++oc){
    for (int oh=0; oh<OH; ++oh){
      for (int ow=0; ow<OW; ++ow){
        output->buf[oc * OH * OW + oh * OW + ow] = input2->buf[oc * OH * OW + oh * OW + ow];
      }
    }
  }
  printf("input1\n");
  // Second half of the output channels: copy input1 (the ConvTranspose2d output,
  // which for this input size is exactly one column narrower than the output)
  // and zero-fill the last column, matching the right-side F.pad in the PyTorch model.
  for (int oc=OC/2; oc<OC; ++oc){
    for (int oh=0; oh<OH; ++oh){
      for (int ow=0; ow<OW; ++ow){
        if (ow == OW-1) {
          output->buf[oc * OH * OW + oh * OW + ow] = 0.0;
          //printf("padding!\n");
        }
        else {
          output->buf[oc * OH * OW + oh * OW + ow] = input1->buf[(oc-OC/2) * H * W + oh * W + ow];
          //printf("[%d] %f\n", (oc-OC/2) * H * W + oh * W + ow, input1->buf[(oc-OC/2) * H * W + oh * W + ow]);
        }
      }
    }
  }
}
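
For reference, a minimal usage sketch of the new Concat against the up1 shapes declared above. It reuses this file's Tensor constructor, and it assumes the down3 skip is {1, 512, 80, 119}, which the concat output shape implies but which is declared outside this hunk:

Tensor *up1_in  = new Tensor({1, 512, 80, 118});   // shaped like up1_convt_0_output
Tensor *skip    = new Tensor({1, 512, 80, 119});   // shaped like down3_batchnorm_1_output
Tensor *cat_out = new Tensor({1, 1024, 80, 119});  // shaped like up1_concat_0_output
Concat(up1_in, skip, cat_out);
// Channels 0..511 of cat_out now hold the skip tensor unchanged; channels 512..1023
// hold the upsampled tensor with its missing last column zero-filled.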