implement up1
This commit is contained in:
parent
a594dc0ec5
commit
b2673ed518
Binary file not shown.
|
@ -48,7 +48,7 @@ class Down(nn.Module):
|
|||
#print("param:", list(self.maxpool_conv[1].double_conv[4].state_dict()['bias']))
|
||||
x1 = self.maxpool_conv(x)
|
||||
#print("input", x)
|
||||
print("last",x1)
|
||||
#print("last",x1)
|
||||
return self.maxpool_conv(x)
|
||||
|
||||
|
||||
|
@ -69,16 +69,20 @@ class Up(nn.Module):
|
|||
def forward(self, x1, x2):
|
||||
x1 = self.up(x1)
|
||||
# input is CHW
|
||||
# x1,2 = x5, x4
|
||||
diffY = x2.size()[2] - x1.size()[2]
|
||||
diffX = x2.size()[3] - x1.size()[3]
|
||||
|
||||
#print(x1.size(), "x1 pad", x1)
|
||||
|
||||
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
|
||||
diffY // 2, diffY - diffY // 2])
|
||||
# if you have padding issues, see
|
||||
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
|
||||
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
|
||||
#print(x1.size(), "x1 pad", x1)
|
||||
|
||||
# print("diffY, X", diffY, diffX, diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2)
|
||||
x = torch.cat([x2, x1], dim=1)
|
||||
#print("bn running mean, var", self.conv.double_conv[4].running_var)
|
||||
#print(x.size(), "cat",x)
|
||||
x3 = self.conv(x)
|
||||
print("x3", x3)
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
|
|
|
@ -135,13 +135,18 @@ Tensor *down4_conv_0_output = new Tensor({1, 1024, 40, 59});
|
|||
Tensor *down4_batchnorm_0_output = new Tensor({1, 1024, 40, 59});
|
||||
Tensor *down4_conv_1_output = new Tensor({1, 1024, 40, 59});
|
||||
Tensor *down4_batchnorm_1_output = new Tensor({1, 1024, 40, 59});
|
||||
Tensor *up1_convt_0_output = new Tensor({1, 512, 80, 119});
|
||||
Tensor *up1_concat_0_output = new Tensor({1, 512, 80, 119});
|
||||
Tensor *up2_convt_0_output = new Tensor({1, 256, 160, 239});
|
||||
Tensor *up1_convt_0_output = new Tensor({1, 512, 80, 118});
|
||||
Tensor *up1_concat_0_output = new Tensor({1, 1024, 80, 119});
|
||||
Tensor *up1_conv_0_output = new Tensor({1, 512, 80, 119});
|
||||
Tensor *up1_batchnorm_0_output = new Tensor({1, 512, 80, 119});
|
||||
Tensor *up1_conv_1_output = new Tensor({1, 512, 80, 119});
|
||||
Tensor *up1_batchnorm_1_output = new Tensor({1, 512, 80, 119});
|
||||
|
||||
Tensor *up2_convt_0_output = new Tensor({1, 256, 160, 238});
|
||||
Tensor *up2_concat_0_output = new Tensor({1, 256, 160, 239});
|
||||
Tensor *up3_convt_0_output = new Tensor({1, 128, 320, 479});
|
||||
Tensor *up3_convt_0_output = new Tensor({1, 128, 320, 478});
|
||||
Tensor *up3_concat_0_output = new Tensor({1, 128, 320, 479});
|
||||
Tensor *up4_convt_0_output = new Tensor({1, 64, 640, 959});
|
||||
Tensor *up4_convt_0_output = new Tensor({1, 64, 640, 958});
|
||||
Tensor *up4_concat_0_output = new Tensor({1, 64, 640, 959});
|
||||
Tensor *outc_conv_0_output = new Tensor({1, 2, 640, 959});
|
||||
|
||||
|
@ -208,11 +213,19 @@ void uNet(Tensor *input, Tensor *output) {
|
|||
ReLU(down4_batchnorm_1_output);
|
||||
|
||||
|
||||
/*
|
||||
// up1(1024, 512), (down4_batchnorm_1_output, down3_batchnorm_1_output)
|
||||
ConvTranspose2d(down4_batchnorm_1_output, up1_up_weight, up1_up_bias, up1_convt_0_output, 2, 0);
|
||||
Concat(up1_convt_0_output, down3_batchnorm_1_output, up1_concat_0_output);
|
||||
|
||||
Concat(up1_convt_0_output, down3_batchnorm_1_output, up1_concat_0_output);
|
||||
|
||||
Conv2d(up1_concat_0_output, up1_conv_double_conv_0_weight, NULL, up1_conv_0_output, 1, 1, 1, false);
|
||||
BatchNorm2d(up1_conv_0_output, up1_conv_double_conv_1_weight, up1_conv_double_conv_1_bias, up1_batchnorm_0_running_mean, up1_batchnorm_0_running_var, up1_batchnorm_0_output, 1e-5, 0.1);
|
||||
ReLU(up1_batchnorm_0_output);
|
||||
Conv2d(up1_batchnorm_0_output, up1_conv_double_conv_3_weight, NULL, up1_conv_1_output, 1, 1, 1, false);
|
||||
BatchNorm2d(up1_conv_1_output, up1_conv_double_conv_4_weight, up1_conv_double_conv_4_bias, up1_batchnorm_1_running_mean, up1_batchnorm_1_running_var, up1_batchnorm_1_output, 1e-5, 0.1);
|
||||
ReLU(up1_batchnorm_1_output);
|
||||
|
||||
|
||||
/*
|
||||
// up2(512, 256), (up1_concat_0_output, down2_batchnorm_1_output)
|
||||
ConvTranspose2d(up1_concat_0_output, up2_up_weight, up2_up_bias, up2_convt_0_output, 2, 0);
|
||||
Concat(up2_convt_0_output, down2_batchnorm_1_output, up2_concat_0_output);
|
||||
|
@ -224,7 +237,8 @@ void uNet(Tensor *input, Tensor *output) {
|
|||
// up4(128, 64), (up3_concat_0_output, inc_batchnorm_1_output)
|
||||
ConvTranspose2d(up3_concat_0_output, up4_up_weight, up4_up_bias, up4_convt_0_output, 2, 0);
|
||||
Concat(up4_convt_0_output, down2_batchnorm_1_output, up4_concat_0_output);
|
||||
|
||||
*/
|
||||
/*
|
||||
// outc(64, n_classes)
|
||||
Conv2d(up4_concat_0_output, outc_conv_weight, outc_conv_bias, output_tensor, 1, 0, 0, true);
|
||||
output = output_tensor->buf;
|
||||
|
@ -411,8 +425,40 @@ void MaxPool2d(Tensor *input, Tensor *output){
|
|||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
void Concat(Tensor *input1, Tensor *input2, Tensor *output){
|
||||
// TODO
|
||||
|
||||
int C = input1->shape[1], H = input1->shape[2], W = input1->shape[3];
|
||||
int C2 = input2->shape[1], H2 = input2->shape[2], W2 = input2->shape[3];
|
||||
int OC = output->shape[1], OH = output->shape[2], OW = output->shape[3];
|
||||
|
||||
//printf("C %d, H %d, W %d, C2 %d, H2 %d, W2 %d, OC %d, OH %d, OW %d\n", C, H, W, C2, H2, W2, OC, OH, OW);
|
||||
for (int oc=0; oc<OC/2; ++oc){
|
||||
for (int oh=0; oh<OH; ++oh){
|
||||
for (int ow=0; ow<OW; ++ow){
|
||||
output->buf[oc * OH * OW + oh * OW + ow] = input2->buf[oc * OH * OW + oh * OW + ow];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
printf("input1\n");
|
||||
for (int oc=OC/2; oc<OC; ++oc){
|
||||
for (int oh=0; oh<OH; ++oh){
|
||||
for (int ow=0; ow<OW; ++ow){
|
||||
if (ow == OW-1) {
|
||||
output->buf[oc * OH * OW + oh * OW + ow] = 0.0;
|
||||
//printf("padding!\n");
|
||||
}
|
||||
else {
|
||||
output->buf[oc * OH * OW + oh * OW + ow] = input1->buf[(oc-OC/2) * H * W + oh * W + ow];
|
||||
//printf("[%d] %f\n", (oc-OC/2) * H * W + oh * W + ow, input1->buf[(oc-OC/2) * H * W + oh * W + ow]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue