[Refactor] minor fix
This commit is contained in:
parent
205831df32
commit
7b254040f6
|
@ -314,6 +314,25 @@ void scaling(Activation* inout, float scale) {
|
|||
}
|
||||
}
|
||||
|
||||
/* Generate mask
|
||||
* @param [in & out] inout: [s, s]
|
||||
* 's' is the number of tokens in the prompt.
|
||||
*/
|
||||
void generate_mask(Activation* inout) {
|
||||
|
||||
size_t s = inout->shape[0];
|
||||
|
||||
for (size_t i = 0; i < s; i++) {
|
||||
for (size_t j = 0; j < s; j++) {
|
||||
if (i >= j) {
|
||||
inout->buf[i*s + j] = 0;
|
||||
} else {
|
||||
inout->buf[i*s + j] = -1e10;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* (Elem-wise) Masking
|
||||
* @param [in1 & out] inout: [N]
|
||||
* @param [in2] mask: [N]
|
||||
|
@ -356,29 +375,6 @@ void add(Activation* inout, Activation* x) {
|
|||
}
|
||||
}
|
||||
|
||||
/* Greedy Max Sampling
|
||||
* @param [in1] in: [s, V]
|
||||
* @return [ret] out: [1]
|
||||
* 's' is the number of tokens in the prompt.
|
||||
* 'V' is the number of vocabulary.
|
||||
*/
|
||||
int greedy_sampling(Activation* in) {
|
||||
|
||||
size_t s = in->shape[0];
|
||||
size_t V = in->shape[1];
|
||||
|
||||
int out = 0;
|
||||
float max = -INFINITY;
|
||||
for (size_t i = 0; i < V; i++) {
|
||||
if (in->buf[(s-1)*V + i] > max) {
|
||||
max = in->buf[(s-1)*V + i];
|
||||
out = i;
|
||||
}
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/* Split into QKV
|
||||
* @param [in1] in: [s, H]
|
||||
* @param [out] out: [3, s, H/3]
|
||||
|
@ -423,25 +419,6 @@ void split_head(Activation* in, size_t n_head,
|
|||
}
|
||||
}
|
||||
|
||||
/* Generate mask
|
||||
* @param [in & out] inout: [s, s]
|
||||
* 's' is the number of tokens in the prompt.
|
||||
*/
|
||||
void generate_mask(Activation* inout) {
|
||||
|
||||
size_t s = inout->shape[0];
|
||||
|
||||
for (size_t i = 0; i < s; i++) {
|
||||
for (size_t j = 0; j < s; j++) {
|
||||
if (i >= j) {
|
||||
inout->buf[i*s + j] = 0;
|
||||
} else {
|
||||
inout->buf[i*s + j] = -1e10;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Extract Q, K, V from QKV head
|
||||
* @param [in1] in: [3, n_head, s, H_]
|
||||
* @param [in2] head_idx: [1]
|
||||
|
@ -471,6 +448,7 @@ void extract_qkv(Activation* in, size_t head_idx, size_t n_head,
|
|||
/* Merge each heads
|
||||
* @param [in1] in: [s, H_]
|
||||
* @param [in2] head_idx: [1]
|
||||
* @param [in3] n_head: [1]
|
||||
* @param [out] out: [n_head, s, H_]
|
||||
* 's' is the number of tokens in the prompt.
|
||||
* 'H_' is the hidden dimension/n_head.
|
||||
|
@ -513,6 +491,29 @@ void concat_head(Activation* in,
|
|||
}
|
||||
}
|
||||
|
||||
/* Greedy Max Sampling
|
||||
* @param [in1] in: [s, V]
|
||||
* @return [ret] out: [1]
|
||||
* 's' is the number of tokens in the prompt.
|
||||
* 'V' is the number of vocabulary.
|
||||
*/
|
||||
int greedy_sampling(Activation* in) {
|
||||
|
||||
size_t s = in->shape[0];
|
||||
size_t V = in->shape[1];
|
||||
|
||||
int out = 0;
|
||||
float max = -INFINITY;
|
||||
for (size_t i = 0; i < V; i++) {
|
||||
if (in->buf[(s-1)*V + i] > max) {
|
||||
max = in->buf[(s-1)*V + i];
|
||||
out = i;
|
||||
}
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
/* (Position-wise) Feed-Forward Network
|
||||
* @param [in1] in: [input_tokens, HIDDEN_DIM]
|
||||
* @param [in2] mlp1_w: [HIDDEN_DIM, 4*HIDDEN_DIM]
|
||||
|
@ -583,7 +584,7 @@ void mha(Activation* in,
|
|||
|
||||
/* QKV projection:
|
||||
[input_tokens, HIDDEN_DIM] ->
|
||||
[input_tokens, 3*HIDDEN_DIM]) */
|
||||
[input_tokens, 3*HIDDEN_DIM] */
|
||||
linear(in, attn_w, attn_b, mha_qkv_proj_a);
|
||||
|
||||
/* Split into Q, K, V:
|
||||
|
|
|
@ -52,6 +52,7 @@ typedef Tensor Activation;
|
|||
/* [Model Operations] */
|
||||
/* Elem-wise operation */
|
||||
void gelu(Activation* inout);
|
||||
void copy(Activation* in, Activation* out);
|
||||
void add(Activation* inout, Activation* x);
|
||||
void scaling(Activation* inout, float scale);
|
||||
void masking(Activation* inout, Activation* mask);
|
||||
|
@ -62,17 +63,18 @@ void linear(Activation* in, Parameter* w, Parameter* b,
|
|||
void transpose(Activation* in, Activation* out);
|
||||
void split_qkv(Activation* in, Activation* out);
|
||||
void split_head(Activation* in, size_t n_head, Activation* out);
|
||||
void merge_head(Activation* in, size_t head_idx, size_t n_head,
|
||||
Activation* out);
|
||||
void concat_head(Activation* in, Activation* out);
|
||||
/* Other operation */
|
||||
void token_pos_embedding(vector<int> in, Parameter* wte, Parameter* wpe,
|
||||
Activation* out);
|
||||
void softmax(Activation* inout);
|
||||
void layer_norm(Activation* inout, Parameter* gamma, Parameter* beta);
|
||||
void copy(Activation* in, Activation* out);
|
||||
int greedy_sampling(Activation* in);
|
||||
void generate_mask(Activation* inout);
|
||||
void extract_qkv(Activation* in, size_t head_idx, size_t n_head,
|
||||
Activation* q, Activation* k, Activation* v);
|
||||
void merge_head(Activation* in, size_t head_idx, size_t n_head,
|
||||
Activation* out);
|
||||
int greedy_sampling(Activation* in);
|
||||
|
||||
|
||||
/* [Model Construction] */
|
||||
|
|
Loading…
Reference in New Issue