#include <immintrin.h> /* AVX / FMA intrinsics: __m256, _mm256_fmadd_ps, ... */
#include <math.h>      /* NOTE(review): the second header name was lost in the garbled source; <math.h> is assumed -- confirm */

/*
 * GCC and Clang reject AVX/FMA intrinsics inside a function that is not
 * compiled for those instruction sets. The target attribute enables them for
 * the one function that needs them, so the file also builds without
 * -mavx -mfma. The CPU must still support AVX and FMA when that function
 * is actually called.
 */
#if defined(__GNUC__) || defined(__clang__)
#define VECTORDOT_FMA_TARGET __attribute__((target("avx,fma")))
#else
#define VECTORDOT_FMA_TARGET
#endif

/*
 * Scalar reference dot product.
 *
 * Returns the sum of A[i] * B[i] for i in [0, N), accumulated left to right
 * in single precision. Returns 0 when N <= 0.
 */
float vectordot_naive(float *A, float *B, int N)
{
    float c = 0.f;
    for (int i = 0; i < N; ++i) {
        c += A[i] * B[i];
    }
    return c;
}

/*
 * Dot product of A and B using 256-bit fused multiply-add.
 *
 * Eight floats are processed per iteration; the up-to-seven elements left
 * over when N is not a multiple of 8 are handled by a scalar tail loop, so
 * nothing outside A[0..N-1] / B[0..N-1] is ever read. The inputs need no
 * particular alignment. Returns 0 when N <= 0.
 *
 * The products are summed in a different order than in vectordot_naive, so
 * the two results may differ by floating-point rounding.
 */
VECTORDOT_FMA_TARGET
float vectordot_fma(float *A, float *B, int N)
{
    enum { LANES = 8 }; /* floats held by one __m256 register */

    if (N <= 0) {
        return 0.f;
    }

    /* Largest multiple of LANES that does not exceed N. */
    const int vec_end = N - (N % LANES);

    __m256 sum = _mm256_setzero_ps();
    for (int i = 0; i < vec_end; i += LANES) {
        /* Unaligned loads: a plain float pointer carries no 32-byte alignment guarantee. */
        const __m256 a = _mm256_loadu_ps(A + i);
        const __m256 b = _mm256_loadu_ps(B + i);
        sum = _mm256_fmadd_ps(a, b, sum); /* sum += a * b, lane by lane */
    }

    /* Horizontal reduction, 8 lanes -> 4: add the upper 128 bits to the lower. */
    const __m128 hiQuad = _mm256_extractf128_ps(sum, 1);
    const __m128 loQuad = _mm256_castps256_ps128(sum);
    const __m128 sumQuad = _mm_add_ps(loQuad, hiQuad);

    /* 4 lanes -> 2: add lanes 2,3 onto lanes 0,1. */
    const __m128 hiDual = _mm_movehl_ps(sumQuad, sumQuad);
    const __m128 sumDual = _mm_add_ps(sumQuad, hiDual);

    /* 2 lanes -> 1: bring lane 1 down to lane 0 and add. */
    const __m128 hi = _mm_shuffle_ps(sumDual, sumDual, 0x1);
    const __m128 res = _mm_add_ss(sumDual, hi);

    float c = _mm_cvtss_f32(res);

    /* Scalar tail for the N % LANES elements the vector loop did not cover. */
    for (int i = vec_end; i < N; ++i) {
        c += A[i] * B[i];
    }
    return c;
}