[Refactor] minor fix

This commit is contained in:
Jaehwan Lee 2024-04-30 09:47:07 +00:00
parent 205831df32
commit 7b254040f6
2 changed files with 50 additions and 47 deletions

View File

@ -314,6 +314,25 @@ void scaling(Activation* inout, float scale) {
}
}
/* Generate mask
* @param [in & out] inout: [s, s]
* 's' is the number of tokens in the prompt.
*/
void generate_mask(Activation* inout) {
size_t s = inout->shape[0];
for (size_t i = 0; i < s; i++) {
for (size_t j = 0; j < s; j++) {
if (i >= j) {
inout->buf[i*s + j] = 0;
} else {
inout->buf[i*s + j] = -1e10;
}
}
}
}
/* (Elem-wise) Masking
* @param [in1 & out] inout: [N]
* @param [in2] mask: [N]
@ -356,29 +375,6 @@ void add(Activation* inout, Activation* x) {
}
}
/* Greedy Max Sampling
* @param [in1] in: [s, V]
* @return [ret] out: [1]
* 's' is the number of tokens in the prompt.
* 'V' is the number of vocabulary.
*/
int greedy_sampling(Activation* in) {
size_t s = in->shape[0];
size_t V = in->shape[1];
int out = 0;
float max = -INFINITY;
for (size_t i = 0; i < V; i++) {
if (in->buf[(s-1)*V + i] > max) {
max = in->buf[(s-1)*V + i];
out = i;
}
}
return out;
}
/* Split into QKV
* @param [in1] in: [s, H]
* @param [out] out: [3, s, H/3]
@ -423,25 +419,6 @@ void split_head(Activation* in, size_t n_head,
}
}
/* Generate mask
* @param [in & out] inout: [s, s]
* 's' is the number of tokens in the prompt.
*/
void generate_mask(Activation* inout) {
size_t s = inout->shape[0];
for (size_t i = 0; i < s; i++) {
for (size_t j = 0; j < s; j++) {
if (i >= j) {
inout->buf[i*s + j] = 0;
} else {
inout->buf[i*s + j] = -1e10;
}
}
}
}
/* Extract Q, K, V from QKV head
* @param [in1] in: [3, n_head, s, H_]
* @param [in2] head_idx: [1]
@ -471,6 +448,7 @@ void extract_qkv(Activation* in, size_t head_idx, size_t n_head,
/* Merge each heads
* @param [in1] in: [s, H_]
* @param [in2] head_idx: [1]
* @param [in3] n_head: [1]
* @param [out] out: [n_head, s, H_]
* 's' is the number of tokens in the prompt.
* 'H_' is the hidden dimension/n_head.
@ -513,6 +491,29 @@ void concat_head(Activation* in,
}
}
/* Greedy Max Sampling
* @param [in1] in: [s, V]
* @return [ret] out: [1]
* 's' is the number of tokens in the prompt.
* 'V' is the number of vocabulary.
*/
int greedy_sampling(Activation* in) {
size_t s = in->shape[0];
size_t V = in->shape[1];
int out = 0;
float max = -INFINITY;
for (size_t i = 0; i < V; i++) {
if (in->buf[(s-1)*V + i] > max) {
max = in->buf[(s-1)*V + i];
out = i;
}
}
return out;
}
/* (Position-wise) Feed-Forward Network
* @param [in1] in: [input_tokens, HIDDEN_DIM]
* @param [in2] mlp1_w: [HIDDEN_DIM, 4*HIDDEN_DIM]
@ -583,7 +584,7 @@ void mha(Activation* in,
/* QKV projection:
[input_tokens, HIDDEN_DIM] ->
[input_tokens, 3*HIDDEN_DIM]) */
[input_tokens, 3*HIDDEN_DIM] */
linear(in, attn_w, attn_b, mha_qkv_proj_a);
/* Split into Q, K, V:

View File

@ -52,6 +52,7 @@ typedef Tensor Activation;
/* [Model Operations] */
/* Elem-wise operation */
void gelu(Activation* inout);
void copy(Activation* in, Activation* out);
void add(Activation* inout, Activation* x);
void scaling(Activation* inout, float scale);
void masking(Activation* inout, Activation* mask);
@ -62,17 +63,18 @@ void linear(Activation* in, Parameter* w, Parameter* b,
void transpose(Activation* in, Activation* out);
void split_qkv(Activation* in, Activation* out);
void split_head(Activation* in, size_t n_head, Activation* out);
void merge_head(Activation* in, size_t head_idx, size_t n_head,
Activation* out);
void concat_head(Activation* in, Activation* out);
/* Other operation */
void token_pos_embedding(vector<int> in, Parameter* wte, Parameter* wpe,
Activation* out);
void softmax(Activation* inout);
void layer_norm(Activation* inout, Parameter* gamma, Parameter* beta);
void copy(Activation* in, Activation* out);
int greedy_sampling(Activation* in);
void generate_mask(Activation* inout);
void extract_qkv(Activation* in, size_t head_idx, size_t n_head,
Activation* q, Activation* k, Activation* v);
void merge_head(Activation* in, size_t head_idx, size_t n_head,
Activation* out);
int greedy_sampling(Activation* in);
/* [Model Construction] */