chundoong-lab-ta/SHPC2022/hw2/matmul/matmul.c

87 lines
2.0 KiB
C
Raw Normal View History

2022-10-04 13:56:31 +09:00
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
/*
 * Per-thread work description handed to matmul_kernel through pthread_create.
 * All workers of one matmul() call share the same matrices and sizes; only
 * `rank` differs between them.
 */
struct thread_arg {
  const float *A;  /* M x K input, row-major: A[i][k] = A[i * K + k] */
  const float *B;  /* K x N input, row-major: B[k][j] = B[k * N + j] */
  float *C;        /* M x N output, row-major: C[i][j] = C[i * N + j] */
  int M;           /* rows of A and C */
  int N;           /* columns of B and C */
  int K;           /* columns of A, rows of B */
  int num_threads; /* total number of worker threads */
  int rank;        /* id of this thread, in [0, num_threads) */
};

/*
 * Argument slots and thread handles for up to 256 workers. Both are
 * file-local (not part of the public interface) and are reused by every
 * matmul() call, so matmul() must not be called concurrently.
 */
static struct thread_arg args[256];
static pthread_t threads[256];
2022-10-04 21:07:02 +09:00
static void *matmul_kernel(void *arg) {
2022-10-04 13:56:31 +09:00
/*
2022-10-04 21:07:02 +09:00
TODO: FILL IN HERE
2022-10-04 13:56:31 +09:00
*/
2022-11-07 12:26:51 +09:00
struct thread_arg *input = (struct thread_arg *)arg;
const float *A = (*input).A;
const float *B = (*input).B;
float *C = (*input).C;
int M = (*input).M;
int N = (*input).N;
int K = (*input).K;
int num_threads = (*input).num_threads;
int rank = (*input).rank;
int div = M / num_threads;
/*
C[i][j] = sum of A[i][k]* B[k][j]
C[i][j] = C[i*N + j]
A[i][k] = A[i*K + j]
B[k][j] = B[k*N + j]
*/
for (int k = 0; k < N; ++k) {
for (int i = rank * div; i < (rank + 1) * div; ++i) {
for (int j = 0; j < K; ++j) {
C[i * N + j] += A[i * K + k] * B[k * N + j];
}
}
}
if (M % num_threads != 0) {
for (int k = 0; k < N; ++k) {
for (int i = div * num_threads; i < M; ++i) {
for (int j = 0; j < K; ++j) {
C[i * N + j] += A[i * K + k] * B[k * N + j];
}
}
}
}
2022-10-04 13:56:31 +09:00
return NULL;
}
2022-10-04 21:07:02 +09:00
void matmul(const float *A, const float *B, float *C, int M, int N, int K,
int num_threads) {
2022-10-04 13:56:31 +09:00
if (num_threads > 256) {
fprintf(stderr, "num_threads must be <= 256\n");
exit(EXIT_FAILURE);
}
int err;
for (int t = 0; t < num_threads; ++t) {
2022-10-04 21:07:02 +09:00
args[t].A = A, args[t].B = B, args[t].C = C, args[t].M = M, args[t].N = N,
args[t].K = K, args[t].num_threads = num_threads, args[t].rank = t;
err = pthread_create(&threads[t], NULL, matmul_kernel, (void *)&args[t]);
2022-10-04 13:56:31 +09:00
if (err) {
printf("pthread_create(%d) failed with err %d\n", t, err);
exit(EXIT_FAILURE);
}
}
for (int t = 0; t < num_threads; ++t) {
err = pthread_join(threads[t], NULL);
if (err) {
printf("pthread_join(%d) failed with err %d\n", t, err);
exit(EXIT_FAILURE);
}
}
}