2023-02-14 01:23:28 +09:00
|
|
|
#include <cublas_v2.h>
|
|
|
|
#include <cuda_runtime_api.h>
|
2023-02-01 20:04:35 +09:00
|
|
|
|
|
|
|
#include <climits>
#include <cstdio>
#include <cstdlib>
|
2023-02-14 01:23:28 +09:00
|
|
|
|
|
|
|
#include "matmul.h"
|
2023-02-01 20:04:35 +09:00
|
|
|
|
|
|
|
// Evaluate a CUDA runtime call exactly once. If it returns anything other
// than cudaSuccess, report the call site and the runtime's error string on
// stderr and terminate the process with EXIT_FAILURE.
#define CHECK_CUDA(call)                                                      \
  do {                                                                        \
    const cudaError_t cuda_err_ = (call);                                     \
    if (cuda_err_ == cudaSuccess) break;                                      \
    fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__,           \
            cudaGetErrorString(cuda_err_));                                   \
    exit(EXIT_FAILURE);                                                       \
  } while (0)
|
|
|
|
|
|
|
|
// Evaluate a cuBLAS call exactly once. On any status other than
// CUBLAS_STATUS_SUCCESS, print the call site together with the status name
// and its human-readable description, then terminate with EXIT_FAILURE.
#define CHECK_CUBLAS(call)                                                    \
  do {                                                                        \
    const cublasStatus_t cublas_err_ = (call);                                \
    if (cublas_err_ == CUBLAS_STATUS_SUCCESS) break;                          \
    fprintf(stderr, "CUBLAS error (%s:%d): %s, %s\n", __FILE__, __LINE__,     \
            cublasGetStatusName(cublas_err_),                                 \
            cublasGetStatusString(cublas_err_));                              \
    exit(EXIT_FAILURE);                                                       \
  } while (0)
|
|
|
|
|
|
|
|
// Device-side buffers for the operands A, B and the result C. Per the step
// comments in this file, they are meant to be allocated in
// matmul_cublas_initialize() and released in matmul_cublas_finalize().
static float *A_gpu = nullptr;
static float *B_gpu = nullptr;
static float *C_gpu = nullptr;

// cuBLAS library context shared by the functions in this file.
static cublasHandle_t handle;
|
|
|
|
|
|
|
|
// Abort with a diagnostic if any matrix dimension does not fit in the `int`
// parameters taken by the cuBLAS v2 API (cublasSetMatrix, cublasSgemm, ...);
// passing a larger size_t would silently truncate it.
static void check_cublas_dims(size_t M, size_t N, size_t K) {
  const size_t int_max = static_cast<size_t>(INT_MAX);
  if (M > int_max || N > int_max || K > int_max) {
    fprintf(stderr,
            "matmul_cublas: dimensions (M=%zu, N=%zu, K=%zu) exceed INT_MAX\n",
            M, N, K);
    exit(EXIT_FAILURE);
  }
}

// Set up cuBLAS for C = A * B with A (M x K), B (K x N), C (M x N).
// Creates the cuBLAS handle and allocates one device buffer per matrix.
// Call once before matmul_cublas() with the same M, N, K; undo with
// matmul_cublas_finalize(). Exits the process on any CUDA/cuBLAS failure.
void matmul_cublas_initialize(size_t M, size_t N, size_t K) {
  check_cublas_dims(M, N, K);

  // 1. Create cublas handle
  CHECK_CUBLAS(cublasCreate(&handle));

  // 2. Allocate GPU memory for A, B, C. The explicit void** cast keeps this
  //    valid even when only cuda_runtime_api.h (no templated overload) is seen.
  CHECK_CUDA(cudaMalloc(reinterpret_cast<void **>(&A_gpu),
                        M * K * sizeof(float)));
  CHECK_CUDA(cudaMalloc(reinterpret_cast<void **>(&B_gpu),
                        K * N * sizeof(float)));
  CHECK_CUDA(cudaMalloc(reinterpret_cast<void **>(&C_gpu),
                        M * N * sizeof(float)));
}

// Compute C = A * B on the GPU with cuBLAS SGEMM.
// Operands are copied host -> device on every call, and C is copied back.
//
// NOTE(review): assumes A (M x K), B (K x N) and C (M x N) are dense,
// row-major host arrays. This is suggested by the per-matrix cublasSetMatrix
// steps of the original skeleton; confirm against matmul.h / the CPU reference.
//
// cuBLAS is column-major, so a row-major X buffer is seen there as X^T. We
// therefore ask cuBLAS for C^T = B^T * A^T (operands swapped, M and N
// swapped). The column-major C^T it writes is byte-for-byte row-major C.
//
// Requires matmul_cublas_initialize(M, N, K) to have been called first.
void matmul_cublas(float *A, float *B, float *C, size_t M, size_t N, size_t K) {
  check_cublas_dims(M, N, K);

  // Empty result: nothing to transfer or compute.
  if (M == 0 || N == 0) return;

  // Empty inner dimension: every dot product is 0. cuBLAS would reject the
  // resulting zero leading dimension (lda = K), so fill C on the host instead.
  if (K == 0) {
    for (size_t i = 0; i < M * N; ++i) C[i] = 0.0f;
    return;
  }

  // The range check above makes these narrowing conversions lossless.
  const int m = static_cast<int>(M);
  const int n = static_cast<int>(N);
  const int k = static_cast<int>(K);
  const int elem = static_cast<int>(sizeof(float));

  // 1. Send A from CPU to GPU: row-major M x K == column-major K x M, ld = K.
  CHECK_CUBLAS(cublasSetMatrix(k, m, elem, A, k, A_gpu, k));

  // 2. Send B from CPU to GPU: row-major K x N == column-major N x K, ld = N.
  CHECK_CUBLAS(cublasSetMatrix(n, k, elem, B, n, B_gpu, n));

  // 3. Run SGEMM: C^T (N x M) = 1 * B^T (N x K) * A^T (K x M) + 0 * C^T.
  const float one = 1.0f, zero = 0.0f;
  CHECK_CUBLAS(cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
                           B_gpu, n, A_gpu, k, &zero, C_gpu, n));

  // 4. Send C from GPU to CPU: column-major N x M, ld = N == row-major M x N.
  //    cublasGetMatrix is the blocking (non-Async) variant, so C is ready on
  //    return.
  CHECK_CUBLAS(cublasGetMatrix(n, m, elem, C_gpu, n, C, n));
}

// Release the device buffers and the cuBLAS handle created by
// matmul_cublas_initialize(). M, N, K are unused but kept for API symmetry.
void matmul_cublas_finalize(size_t M, size_t N, size_t K) {
  // 1. Free GPU memory for A, B, C. cudaFree(nullptr) is a documented no-op.
  CHECK_CUDA(cudaFree(A_gpu));
  CHECK_CUDA(cudaFree(B_gpu));
  CHECK_CUDA(cudaFree(C_gpu));
  // Clear the pointers so a stale pointer is never reused after finalize.
  A_gpu = B_gpu = C_gpu = nullptr;

  // 2. Destroy cublas handle
  CHECK_CUBLAS(cublasDestroy(handle));
}
|