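// chundoong-lab-ta/APWS23/convolution-skeleton/util.cpp
// C++ source (116 lines, 3.4 KiB); revisions dated 2023-02-01 20:04:35 +09:00
// and 2023-02-14 01:23:28 +09:00.
//
// Helpers for the convolution skeleton: a monotonic timer, NCHW float tensor
// allocation / random fill / zeroing / printing, and validation of a
// convolution output against a naive reference implementation.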
#include "util.h"
#include <math.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
// Returns a timestamp in seconds read from CLOCK_MONOTONIC. The origin is
// unspecified, so the value is only meaningful as a difference between two
// calls (elapsed time), not as time of day.
double get_current_time() {
  struct timespec now;
  clock_gettime(CLOCK_MONOTONIC, &now);
  const double whole_seconds = now.tv_sec;
  const double fractional_seconds = now.tv_nsec * 1e-9;
  return whole_seconds + fractional_seconds;
}
// Allocates an uninitialized buffer large enough for an N x C x H x W float
// tensor and stores it in *m. The buffer is 32-byte aligned and must be
// released with free(). If the size cannot be represented (negative
// dimension or size_t overflow) or the allocation fails, prints an error and
// exits the process with a failure status.
void alloc_tensor(float **m, int N, int C, int H, int W) {
  const size_t alignment = 32;
  // Upper bound that still leaves room to round up to `alignment` below.
  const size_t size_limit = (size_t) -1 - (alignment - 1);
  const int dims[] = {N, C, H, W};

  // Compute the byte count in size_t with overflow checks: an int product
  // N * C * H * W overflows (UB) for large tensors, and a negative dimension
  // would otherwise wrap into an enormous request.
  bool size_ok = true;
  size_t bytes = sizeof(float);
  for (size_t i = 0; i < sizeof dims / sizeof dims[0]; ++i) {
    if (dims[i] < 0 ||
        (dims[i] > 0 && bytes > size_limit / (size_t) dims[i])) {
      size_ok = false;
      break;
    }
    bytes *= (size_t) dims[i];
  }

  if (size_ok) {
    // aligned_alloc requires the size to be an integral multiple of the
    // alignment (implementations may return NULL otherwise), so round up.
    bytes = (bytes + alignment - 1) / alignment * alignment;
    // A zero-byte request has an implementation-defined result (possibly
    // NULL); always ask for at least one aligned chunk.
    if (bytes == 0) { bytes = alignment; }
  }

  *m = size_ok ? (float *) aligned_alloc(alignment, bytes) : NULL;
  if (*m == NULL) {
    printf("Failed to allocate memory for tensor.\n");
    exit(EXIT_FAILURE);  // non-zero status: this is an error, not success
  }
}
// Fills the N*C*H*W contiguous elements of m with pseudo-random values from
// rand(), mapped to [-0.5, 0.5]. The values depend on the current srand()
// seed. Empty or negative shapes leave m untouched.
void rand_tensor(float *m, int N, int C, int H, int W) {
  // Nothing to fill for empty (or invalid negative) shapes; checking first
  // also keeps the size_t product below from wrapping.
  if (N <= 0 || C <= 0 || H <= 0 || W <= 0) { return; }
  // Count in size_t: the int product N * C * H * W can overflow (UB) for
  // large tensors.
  const size_t L = (size_t) N * C * H * W;
  for (size_t j = 0; j < L; j++) { m[j] = (float) rand() / RAND_MAX - 0.5; }
}
// Sets the N*C*H*W contiguous elements of m to all-bits-zero, which is +0.0f
// on IEEE-754 platforms. Empty or negative shapes leave m untouched.
void zero_tensor(float *m, int N, int C, int H, int W) {
  // Skip empty (or invalid negative) shapes: a negative int count would be
  // converted into a huge memset length, and memset on a NULL pointer is UB
  // even for zero bytes.
  if (N <= 0 || C <= 0 || H <= 0 || W <= 0) { return; }
  // Count in size_t: the int product N * C * H * W can overflow (UB).
  const size_t L = (size_t) N * C * H * W;
  memset(m, 0, sizeof(float) * L);
}
// Prints an N x C x H x W tensor stored contiguously in NCHW order to stdout.
// Each (batch, channel) pair gets a "Batch n, Channel c" header followed by H
// lines of W values, each formatted as "%+.3f ".
void print_tensor(float *m, int N, int C, int H, int W) {
  // The loops visit elements in storage order, so a cursor advanced by one
  // element per printed value lands on ((n * C + c) * H + h) * W + w.
  const float *cursor = m;
  for (int batch = 0; batch < N; ++batch) {
    for (int chan = 0; chan < C; ++chan) {
      printf("Batch %d, Channel %d\n", batch, chan);
      for (int row = 0; row < H; ++row) {
        for (int col = 0; col < W; ++col) {
          printf("%+.3f ", *cursor++);
        }
        printf("\n");
      }
    }
  }
}
2023-02-14 01:23:28 +09:00
void check_convolution(float *I, float *F, float *O, int N, int C, int H, int W,
int K, int R, int S, int pad_h, int pad_w, int stride_h,
int stride_w, int dilation_h, int dilation_w) {
2023-02-01 20:04:35 +09:00
float *O_ans;
const int ON = N;
const int OC = K;
const int OH = 1 + (H + 2 * pad_h - (((R - 1) * dilation_h) + 1)) / stride_h;
const int OW = 1 + (W + 2 * pad_w - (((S - 1) * dilation_w) + 1)) / stride_w;
alloc_tensor(&O_ans, ON, OC, OH, OW);
zero_tensor(O_ans, ON, OC, OH, OW);
2023-02-14 01:23:28 +09:00
#pragma omp parallel for
2023-02-01 20:04:35 +09:00
for (int on = 0; on < ON; ++on) {
for (int oc = 0; oc < OC; ++oc) {
for (int oh = 0; oh < OH; ++oh) {
for (int ow = 0; ow < OW; ++ow) {
float sum = 0;
for (int c = 0; c < C; ++c) {
for (int r = 0; r < R; ++r) {
for (int s = 0; s < S; ++s) {
const int n = on;
const int h = oh * stride_h - pad_h + r * dilation_h;
const int w = ow * stride_w - pad_w + s * dilation_w;
const int k = oc;
if (h < 0 || h >= H || w < 0 || w >= W) continue;
2023-02-14 01:23:28 +09:00
sum += I[((n * C + c) * H + h) * W + w] *
F[((k * C + c) * R + r) * S + s];
2023-02-01 20:04:35 +09:00
}
}
}
O_ans[((on * OC + oc) * OH + oh) * OW + ow] = sum;
}
}
}
}
bool is_valid = true;
int cnt = 0, thr = 10;
float eps = 1e-3;
for (int on = 0; on < ON; ++on) {
for (int oc = 0; oc < OC; ++oc) {
for (int oh = 0; oh < OH; ++oh) {
for (int ow = 0; ow < OW; ++ow) {
float o = O[((on * OC + oc) * OH + oh) * OW + ow];
float o_ans = O_ans[((on * OC + oc) * OH + oh) * OW + ow];
if (fabsf(o - o_ans) > eps &&
(o_ans == 0 || fabsf((o - o_ans) / o_ans) > eps)) {
++cnt;
if (cnt <= thr)
2023-02-14 01:23:28 +09:00
printf(
"O[%d][%d][%d][%d] : correct_value = %f, your_value = %f\n",
on, oc, oh, ow, o_ans, o);
2023-02-01 20:04:35 +09:00
if (cnt == thr + 1)
2023-02-14 01:23:28 +09:00
printf("Too many error, only first %d values are printed.\n",
thr);
2023-02-01 20:04:35 +09:00
is_valid = false;
}
}
}
}
}
if (is_valid) {
printf("Validation Result: VALID\n");
} else {
printf("Validation Result: INVALID\n");
}
}