40 lines
954 B
C
40 lines
954 B
C
#include <immintrin.h>
|
|
#include <math.h>
|
|
|
|
/**
 * Scalar reference implementation of a dot product.
 *
 * Walks both arrays in index order, multiplying A[i] by B[i] and adding each
 * product to a single float accumulator.
 *
 * @param A  first operand; at least N floats are read
 * @param B  second operand; at least N floats are read
 * @param N  number of element pairs to process
 * @return   the accumulated sum of A[i] * B[i] for i in [0, N); 0 when N <= 0
 */
float vectordot_naive(float *A, float *B, int N) {
    float acc = 0.f;
    int idx = 0;

    while (idx < N) {
        acc += A[idx] * B[idx];
        ++idx;
    }

    return acc;
}
|
|
|
|
/**
 * Dot product of two float arrays using 256-bit AVX loads and FMA.
 *
 * The bulk of the input is processed eight floats at a time with a fused
 * multiply-add into an 8-lane accumulator; the lanes are then summed
 * horizontally, and any remaining (N % 8) elements are added with a scalar
 * loop so that no memory beyond index N-1 is ever read.
 *
 * Unaligned loads are used, so A and B do not need any particular alignment.
 * Because the additions are grouped per lane, the result may differ from a
 * strictly sequential scalar sum by floating-point rounding.
 *
 * @param A  first operand; exactly N floats are read
 * @param B  second operand; exactly N floats are read
 * @param N  number of element pairs to process
 * @return   sum of A[i] * B[i] for i in [0, N); 0 when N <= 0
 */
#if defined(__GNUC__) && !(defined(__AVX__) && defined(__FMA__))
/* Lets this one function use AVX/FMA even when the translation unit is not
 * built with -mavx -mfma; the caller is still responsible for running it
 * only on CPUs that support those instructions. */
__attribute__((target("avx,fma")))
#endif
float vectordot_fma(float *A, float *B, int N) {
    enum { LANES = 8 };  /* floats per __m256 register */

    /* Largest multiple of LANES that is <= N (0 or negative for N <= 0). */
    const int vec_end = N - (N % LANES);
    int i = 0;

    __m256 sum = _mm256_setzero_ps();
    for (; i < vec_end; i += LANES) {
        /* loadu: callers pass plain float*, so 32-byte alignment is not guaranteed. */
        const __m256 a = _mm256_loadu_ps(A + i);
        const __m256 b = _mm256_loadu_ps(B + i);
        sum = _mm256_fmadd_ps(a, b, sum);
    }

    /* Horizontal reduction, 8 lanes -> 4: add the upper 128 bits to the lower. */
    const __m128 hi_quad = _mm256_extractf128_ps(sum, 1);
    const __m128 lo_quad = _mm256_castps256_ps128(sum);
    const __m128 sum_quad = _mm_add_ps(lo_quad, hi_quad);

    /* 4 lanes -> 2: bring elements 2,3 down onto elements 0,1 and add. */
    const __m128 hi_dual = _mm_movehl_ps(sum_quad, sum_quad);
    const __m128 sum_dual = _mm_add_ps(sum_quad, hi_dual);

    /* 2 lanes -> 1: bring element 1 into element 0 and add the low lane only. */
    const __m128 hi_single = _mm_shuffle_ps(sum_dual, sum_dual, 0x1);
    const __m128 total = _mm_add_ss(sum_dual, hi_single);

    float c = _mm_cvtss_f32(total);

    /* Scalar tail for the last N % LANES elements. */
    for (; i < N; ++i) {
        c += A[i] * B[i];
    }

    return c;
}
|