2022-10-04 13:56:31 +09:00
|
|
|
#include <pthread.h>
|
|
|
|
#include <stdio.h>
|
|
|
|
#include <stdlib.h>
|
|
|
|
|
|
|
|
/*
 * Work descriptor handed to one worker thread.
 *
 * The worker indexes the matrices as flat arrays: A[i * K + k],
 * B[k * N + j] and C[i * N + j], with i ranging over a slice of [0, M).
 */
struct thread_arg {
  const float *A, *B; /* input matrices (read only) */
  float *C;           /* output matrix, accumulated into */
  int M, N, K;        /* dimensions used for the indexing above */
  int num_threads;    /* number of workers sharing the rows of C */
  int rank;           /* id of this thread */
};

/* One descriptor slot per worker thread. */
struct thread_arg args[256];

/* Thread handles; slot t pairs with args[t]. */
static pthread_t threads[256];
|
2022-10-04 21:07:02 +09:00
|
|
|
|
|
|
|
static void *matmul_kernel(void *arg) {
|
|
|
|
struct thread_arg *_arg = (struct thread_arg *)arg;
|
2022-10-04 13:56:31 +09:00
|
|
|
const float *A = _arg->A;
|
|
|
|
const float *B = _arg->B;
|
|
|
|
float *C = _arg->C;
|
|
|
|
int M = _arg->M;
|
|
|
|
int N = _arg->N;
|
|
|
|
int K = _arg->K;
|
2022-10-04 21:07:02 +09:00
|
|
|
int num_threads = _arg->num_threads;
|
2022-10-04 13:56:31 +09:00
|
|
|
int rank = _arg->rank;
|
|
|
|
|
2022-10-04 21:07:02 +09:00
|
|
|
int start = ((M + num_threads - 1) / num_threads) * rank;
|
|
|
|
int end = ((M + num_threads - 1) / num_threads) * (rank + 1);
|
2022-10-04 13:56:31 +09:00
|
|
|
|
2022-10-04 21:07:02 +09:00
|
|
|
if (end >= M)
|
|
|
|
end = M;
|
2022-10-04 13:56:31 +09:00
|
|
|
for (int i = start; i < end; ++i) {
|
|
|
|
for (int k = 0; k < K; ++k) {
|
|
|
|
for (int j = 0; j < N; ++j) {
|
|
|
|
C[i * N + j] += A[i * K + k] * B[k * N + j];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return NULL;
|
|
|
|
}
|
|
|
|
|
2022-10-04 21:07:02 +09:00
|
|
|
void matmul(const float *A, const float *B, float *C, int M, int N, int K,
|
|
|
|
int num_threads) {
|
2022-10-04 13:56:31 +09:00
|
|
|
int err;
|
|
|
|
for (int t = 0; t < num_threads; ++t) {
|
2022-10-04 21:07:02 +09:00
|
|
|
args[t].A = A, args[t].B = B, args[t].C = C, args[t].M = M, args[t].N = N,
|
|
|
|
args[t].K = K, args[t].num_threads = num_threads, args[t].rank = t;
|
|
|
|
err = pthread_create(&threads[t], NULL, matmul_kernel, (void *)&args[t]);
|
2022-10-04 13:56:31 +09:00
|
|
|
if (err) {
|
|
|
|
printf("pthread_create(%d) failed with err %d\n", t, err);
|
|
|
|
exit(EXIT_FAILURE);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
for (int t = 0; t < num_threads; ++t) {
|
|
|
|
err = pthread_join(threads[t], NULL);
|
|
|
|
if (err) {
|
|
|
|
printf("pthread_join(%d) failed with err %d\n", t, err);
|
|
|
|
exit(EXIT_FAILURE);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|