#include #include "matmul.h" static __global__ void matmul_kernel(float *A, float *B, float *C, int M, int N, int K) { int i = blockDim.x * blockIdx.x + threadIdx.x; int j = blockDim.y * blockIdx.y + threadIdx.y; if (i >= M || j >= N) return; float sum = 0.0; for (int k = 0; k < K; ++k) sum += A[i * K + k] * B[k * N + j]; C[i * N + j] = sum; } #define BLOCKS 4 static size_t Mbegin[BLOCKS], Mend[BLOCKS]; static cudaStream_t data_stream, calc_stream; static cudaEvent_t events[BLOCKS]; static float *A_gpu, *B_gpu, *C_gpu; void matmul_buffering_initialize(size_t M, size_t N, size_t K) { for (size_t i = 0; i < BLOCKS; i++) { Mbegin[i] = M / BLOCKS * i; Mend[i] = M / BLOCKS * (i + 1); if (i == BLOCKS - 1) Mend[i] = M; } cudaStreamCreate(&data_stream); cudaStreamCreate(&calc_stream); for (int i = 0; i < BLOCKS; i++) { cudaEventCreate(&events[i]); } cudaMalloc(&A_gpu, M * K * sizeof(float)); cudaMalloc(&B_gpu, K * N * sizeof(float)); cudaMalloc(&C_gpu, M * N * sizeof(float)); } void matmul_buffering(float *A, float *B, float *C, size_t M, size_t N, size_t K) { cudaMemcpyAsync(B_gpu, B, K * N * sizeof(float), cudaMemcpyHostToDevice, data_stream); for (int i = 0; i < BLOCKS; i++) { cudaMemcpyAsync(&A_gpu[Mbegin[i] * K], &A[Mbegin[i] * K], (Mend[i] - Mbegin[i]) * K * sizeof(float), cudaMemcpyHostToDevice, data_stream); cudaEventRecord(events[i], data_stream); } for (int i = 0; i < BLOCKS; i++) { dim3 blockDim(32, 32); dim3 gridDim((Mend[i] - Mbegin[i] + 32 - 1) / 32, (N + 32 - 1) / 32); cudaStreamWaitEvent(calc_stream, events[i]); matmul_kernel<<>>(&A_gpu[Mbegin[i] * K], B_gpu, &C_gpu[Mbegin[i] * N], (Mend[i] - Mbegin[i]), N, K); } cudaStreamSynchronize(calc_stream); cudaMemcpyAsync(C, C_gpu, M * N * sizeof(float), cudaMemcpyDeviceToHost, data_stream); cudaStreamSynchronize(data_stream); } void matmul_buffering_finalize(size_t M, size_t N, size_t K) { cudaFree(A_gpu); cudaFree(B_gpu); cudaFree(C_gpu); cudaStreamDestroy(data_stream); cudaStreamDestroy(calc_stream); for (int i = 0; i < BLOCKS; i++) { cudaEventDestroy(events[i]); } }