update token count

Webifi 2023-08-16 01:51:49 -05:00
parent 86f427f62f
commit f4d9774423
1 changed file with 72 additions and 64 deletions

@@ -1,7 +1,7 @@
 <script context="module" lang="ts">
   import { ChatCompletionResponse } from '../../ChatCompletionResponse.svelte'
   import { ChatRequest } from '../../ChatRequest.svelte'
-  import { getDeliminator, getEndpoint, getLeadPrompt, getModelDetail, getRoleEnd, getRoleTag, getStartSequence, getStopSequence } from '../../Models.svelte'
+  import { countTokens, getDeliminator, getEndpoint, getLeadPrompt, getModelDetail, getRoleEnd, getRoleTag, getStartSequence, getStopSequence } from '../../Models.svelte'
   import type { ChatCompletionOpts, Message, Request } from '../../Types.svelte'
   import { getModelMaxTokens } from '../../Stats.svelte'
   import { updateMessages } from '../../Storage.svelte'
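
The import change above is the heart of this commit: the prompt token count is now computed up front with countTokens from Models.svelte instead of being read back from the ChatCompletionResponse. As a rough mental model only (the real countTokens is model-aware and uses the model's actual tokenizer), a counter of this shape would satisfy the call sites in this diff; the ~4 characters per token heuristic is an assumption, not the project's implementation:

// Hypothetical sketch; the real countTokens in Models.svelte dispatches
// on the model's tokenizer rather than using a character heuristic.
const approxCountTokens = (model: string, text: string): number => {
  // ~4 characters per token is a common rough estimate for English text.
  return Math.ceil(text.length / 4)
}

approxCountTokens('example-model', 'Hello, world!') // ≈ 4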
@@ -35,11 +35,78 @@ export const chatRequest = async (
   stopSequences = stopSequences.sort((a, b) => b.length - a.length)
   const stopSequencesC = stopSequences.filter(s => s !== stopSequence)
   const maxTokens = getModelMaxTokens(model)
+  // Enforce strict order of messages
+  const fMessages = (request.messages || [] as Message[])
+  const rMessages = fMessages.reduce((a, m, i) => {
+    a.push(m)
+    const nm = fMessages[i + 1]
+    if (m.role === 'system' && (!nm || nm.role !== 'user')) {
+      const nc = {
+        role: 'user',
+        content: ''
+      } as Message
+      a.push(nc)
+    }
+    return a
+  },
+  [] as Message[])
+  // make sure top_p and temperature are set the way we need
+  let temperature = request.temperature
+  if (temperature === undefined || isNaN(temperature as any)) temperature = 1
+  if (!temperature || temperature <= 0) temperature = 0.01
+  let topP = request.top_p
+  if (topP === undefined || isNaN(topP as any)) topP = 1
+  if (!topP || topP <= 0) topP = 0.01
+  // build the message array
+  const buildMessage = (m: Message): string => {
+    return getRoleTag(m.role, model, chat) + m.content + getRoleEnd(m.role, model, chat)
+  }
+  const inputArray = rMessages.reduce((a, m, i) => {
+    let c = buildMessage(m)
+    let replace = false
+    const lm = a[a.length - 1]
+    // Merge content if needed
+    if (lm) {
+      if (lm.role === 'system' && m.role === 'user' && c.includes('[[SYSTEM_PROMPT]]')) {
+        c = c.replaceAll('[[SYSTEM_PROMPT]]', lm.content)
+        replace = true
+      } else {
+        c = c.replaceAll('[[SYSTEM_PROMPT]]', '')
+      }
+      if (lm.role === 'user' && m.role === 'assistant' && c.includes('[[USER_PROMPT]]')) {
+        c = c.replaceAll('[[USER_PROMPT]]', lm.content)
+        replace = true
+      } else {
+        c = c.replaceAll('[[USER_PROMPT]]', '')
+      }
+    }
+    // Clean up merge fields on last
+    if (!rMessages[i + 1]) {
+      c = c.replaceAll('[[USER_PROMPT]]', '').replaceAll('[[SYSTEM_PROMPT]]', '')
+    }
+    const result = {
+      role: m.role,
+      content: c.trim()
+    } as Message
+    if (replace) {
+      a[a.length - 1] = result
+    } else {
+      a.push(result)
+    }
+    return a
+  }, [] as Message[])
+  const leadPrompt = (leadPromptSequence && ((inputArray[inputArray.length - 1] || {}) as Message).role !== 'assistant') ? deliminator + leadPromptSequence : ''
+  const fullPromptInput = getStartSequence(chat) + inputArray.map(m => m.content).join(deliminator) + leadPrompt
   let maxLen = Math.min(opts.maxTokens || chatSettings.max_tokens || maxTokens, maxTokens)
-  const promptTokenCount = chatResponse.getPromptTokenCount()
+  const promptTokenCount = countTokens(model, fullPromptInput)
   if (promptTokenCount > maxLen) {
     maxLen = Math.min(maxLen + promptTokenCount, maxTokens)
   }
+  // update with real count
+  chatResponse.setPromptTokenCount(promptTokenCount)
+  // set up the request
   chatResponse.onFinish(() => {
     const message = chatResponse.getMessages()[0]
     if (message) {
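
The large block added in this hunk moves the prompt assembly ahead of the token count, so fullPromptInput exists before countTokens runs. Its trickiest part is the merge-field handling, so here is a trimmed, standalone sketch of just the [[SYSTEM_PROMPT]] path (role tags are simplified to a hypothetical [role]...[/role] wrapper; the real code uses getRoleTag/getRoleEnd and handles [[USER_PROMPT]] the same way):

type Msg = { role: string, content: string }

// Simplified stand-in for buildMessage; the real code uses getRoleTag/getRoleEnd.
const tag = (m: Msg) => `[${m.role}]${m.content}[/${m.role}]`

const merge = (msgs: Msg[]): Msg[] =>
  msgs.reduce((a, m, i) => {
    let c = tag(m)
    let replace = false
    const lm = a[a.length - 1]
    if (lm) {
      // A user turn containing [[SYSTEM_PROMPT]] absorbs the preceding system turn.
      if (lm.role === 'system' && m.role === 'user' && c.includes('[[SYSTEM_PROMPT]]')) {
        c = c.replaceAll('[[SYSTEM_PROMPT]]', lm.content)
        replace = true
      } else {
        c = c.replaceAll('[[SYSTEM_PROMPT]]', '')
      }
    }
    // Strip any leftover merge field from the final message.
    if (!msgs[i + 1]) c = c.replaceAll('[[SYSTEM_PROMPT]]', '')
    const result = { role: m.role, content: c.trim() }
    if (replace) a[a.length - 1] = result
    else a.push(result)
    return a
  }, [] as Msg[])

merge([
  { role: 'system', content: 'Be brief.' },
  { role: 'user', content: '[[SYSTEM_PROMPT]] Hi!' }
])
// -> a single merged user entry whose content embeds the system prompt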
@@ -69,70 +136,9 @@ export const chatRequest = async (
       console.error(err)
       throw err
     }
-  // Enforce strict order of messages
-  const fMessages = (request.messages || [] as Message[])
-  const rMessages = fMessages.reduce((a, m, i) => {
-    a.push(m)
-    const nm = fMessages[i + 1]
-    if (m.role === 'system' && (!nm || nm.role !== 'user')) {
-      const nc = {
-        role: 'user',
-        content: ''
-      } as Message
-      a.push(nc)
-    }
-    return a
-  },
-  [] as Message[])
-  // make sure top_p and temperature are set the way we need
-  let temperature = request.temperature
-  if (temperature === undefined || isNaN(temperature as any)) temperature = 1
-  if (!temperature || temperature <= 0) temperature = 0.01
-  let topP = request.top_p
-  if (topP === undefined || isNaN(topP as any)) topP = 1
-  if (!topP || topP <= 0) topP = 0.01
-  // build the message array
-  const buildMessage = (m: Message): string => {
-    return getRoleTag(m.role, model, chat) + m.content + getRoleEnd(m.role, model, chat)
-  }
-  const inputArray = rMessages.reduce((a, m, i) => {
-    let c = buildMessage(m)
-    let replace = false
-    const lm = a[a.length - 1]
-    // Merge content if needed
-    if (lm) {
-      if (lm.role === 'system' && m.role === 'user' && c.includes('[[SYSTEM_PROMPT]]')) {
-        c = c.replaceAll('[[SYSTEM_PROMPT]]', lm.content)
-        replace = true
-      } else {
-        c = c.replaceAll('[[SYSTEM_PROMPT]]', '')
-      }
-      if (lm.role === 'user' && m.role === 'assistant' && c.includes('[[USER_PROMPT]]')) {
-        c = c.replaceAll('[[USER_PROMPT]]', lm.content)
-        replace = true
-      } else {
-        c = c.replaceAll('[[USER_PROMPT]]', '')
-      }
-    }
-    // Clean up merge fields on last
-    if (!rMessages[i + 1]) {
-      c = c.replaceAll('[[USER_PROMPT]]', '').replaceAll('[[SYSTEM_PROMPT]]', '')
-    }
-    const result = {
-      role: m.role,
-      content: c.trim()
-    } as Message
-    if (replace) {
-      a[a.length - 1] = result
-    } else {
-      a.push(result)
-    }
-    return a
-  }, [] as Message[])
-  const leadPrompt = (leadPromptSequence && ((inputArray[inputArray.length - 1] || {}) as Message).role !== 'assistant') ? deliminator + leadPromptSequence : ''
   const petalsRequest = {
     type: 'generate',
-    inputs: getStartSequence(chat) + inputArray.map(m => m.content).join(deliminator) + leadPrompt,
+    inputs: fullPromptInput,
     max_new_tokens: 1, // wait for up to 1 tokens before displaying
     stop_sequence: stopSequence,
     do_sample: 1, // enable top p and the like
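
For reference, the fields visible in this hunk imply the rough shape of the generate request sent over the Petals WebSocket. The interface below is reconstructed from the diff alone and is not an official Petals type; the '</s>' stop sequence is a placeholder:

// Sketch of the request shape implied by this hunk; only the fields
// visible in the diff are confirmed, the rest is assumption.
interface PetalsGenerateRequest {
  type: 'generate'
  inputs: string                  // fullPromptInput: start sequence + joined messages + lead prompt
  max_new_tokens: number          // 1 here, so the first token is shown as soon as it arrives
  stop_sequence: string           // primary stop sequence
  do_sample: number               // 1 enables top-p / temperature sampling
  extra_stop_sequences?: string[] // remaining stop sequences, sorted longest first
}

const req: PetalsGenerateRequest = {
  type: 'generate',
  inputs: '...assembled prompt...',
  max_new_tokens: 1,
  stop_sequence: '</s>', // placeholder value
  do_sample: 1
}
// ws.send(JSON.stringify(req))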
@@ -141,6 +147,8 @@ export const chatRequest = async (
     // repitition_penalty: chatSettings.repititionPenalty
   } as any
   if (stopSequencesC.length) petalsRequest.extra_stop_sequences = stopSequencesC
+  // Update token count
+  chatResponse.setPromptTokenCount(promptTokenCount)
   ws.send(JSON.stringify(petalsRequest))
   ws.onmessage = event => {
     // Remove updating indicator
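
One more detail worth calling out from the moved block: the temperature/top_p clamping. Undefined or NaN values fall back to 1, and anything at or below zero is floored at 0.01, presumably because do_sample is enabled and the backend cannot sample with zero or negative values. Restated as a standalone helper (clampSampling is an illustrative name, not a function in this codebase):

// Restates the clamping logic added in this commit: undefined/NaN -> fallback,
// then anything <= 0 is floored so sampling never receives a non-positive value.
const clampSampling = (v: number | undefined, fallback = 1, floor = 0.01): number => {
  let out = v === undefined || isNaN(v as any) ? fallback : v
  if (!out || out <= 0) out = floor
  return out
}

const temperature = clampSampling(undefined) // 1
const topP = clampSampling(0)                // 0.01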