chundoong-lab-ta/SHPC2022/hw2_answer/vectordot/vectordot.c

38 lines
793 B
C

#include <immintrin.h>
#include <math.h>
/*
 * Scalar reference dot product.
 *
 * A, B: input vectors; the first N elements of each are read.
 * N:    number of elements to combine; when N <= 0 nothing is read.
 *
 * Returns the sum of A[i] * B[i] for i in [0, N), accumulated in index
 * order in single precision (0 when N <= 0).
 */
float vectordot_naive(float *A, float *B, int N) {
  const float *pa = A;
  const float *pb = B;
  float total = 0.f;

  /* Walk both vectors in lock-step, front to back. */
  for (int remaining = N; remaining > 0; --remaining) {
    total += pa[0] * pb[0];
    ++pa;
    ++pb;
  }
  return total;
}
float vectordot_fma(float *A, float *B, int N) {
/*
TODO: FILL IN HERE
*/
__m256 cvec = _mm256_set_ps(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f);
int i = 0;
for (i = 0; i < N / 8; ++i) {
__m256 avec = _mm256_load_ps(&A[i * 8]);
__m256 bvec = _mm256_load_ps(&B[i * 8]);
cvec = _mm256_fmadd_ps(avec, bvec, cvec);
}
float c = 0.f;
for (i = 0; i < N % 8; ++i) {
c = fmaf(A[8 * (N / 8) + i], B[8 * (N / 8) + i], c);
}
float *vecp = &cvec;
c += vecp[0] + vecp[1] + vecp[2] + vecp[3] + vecp[4] + vecp[5] + vecp[6] +
vecp[7];
return c;
}