update token count
parent 86f427f62f
commit f4d9774423
@@ -1,7 +1,7 @@
 <script context="module" lang="ts">
   import { ChatCompletionResponse } from '../../ChatCompletionResponse.svelte'
   import { ChatRequest } from '../../ChatRequest.svelte'
-  import { getDeliminator, getEndpoint, getLeadPrompt, getModelDetail, getRoleEnd, getRoleTag, getStartSequence, getStopSequence } from '../../Models.svelte'
+  import { countTokens, getDeliminator, getEndpoint, getLeadPrompt, getModelDetail, getRoleEnd, getRoleTag, getStartSequence, getStopSequence } from '../../Models.svelte'
   import type { ChatCompletionOpts, Message, Request } from '../../Types.svelte'
   import { getModelMaxTokens } from '../../Stats.svelte'
   import { updateMessages } from '../../Storage.svelte'
@@ -35,11 +35,78 @@ export const chatRequest = async (
   stopSequences = stopSequences.sort((a, b) => b.length - a.length)
   const stopSequencesC = stopSequences.filter(s => s !== stopSequence)
   const maxTokens = getModelMaxTokens(model)
+
+  // Enforce strict order of messages
+  const fMessages = (request.messages || [] as Message[])
+  const rMessages = fMessages.reduce((a, m, i) => {
+    a.push(m)
+    const nm = fMessages[i + 1]
+    if (m.role === 'system' && (!nm || nm.role !== 'user')) {
+      const nc = {
+        role: 'user',
+        content: ''
+      } as Message
+      a.push(nc)
+    }
+    return a
+  },
+  [] as Message[])
+  // make sure top_p and temperature are set the way we need
+  let temperature = request.temperature
+  if (temperature === undefined || isNaN(temperature as any)) temperature = 1
+  if (!temperature || temperature <= 0) temperature = 0.01
+  let topP = request.top_p
+  if (topP === undefined || isNaN(topP as any)) topP = 1
+  if (!topP || topP <= 0) topP = 0.01
+  // build the message array
+  const buildMessage = (m: Message): string => {
+    return getRoleTag(m.role, model, chat) + m.content + getRoleEnd(m.role, model, chat)
+  }
+  const inputArray = rMessages.reduce((a, m, i) => {
+    let c = buildMessage(m)
+    let replace = false
+    const lm = a[a.length - 1]
+    // Merge content if needed
+    if (lm) {
+      if (lm.role === 'system' && m.role === 'user' && c.includes('[[SYSTEM_PROMPT]]')) {
+        c = c.replaceAll('[[SYSTEM_PROMPT]]', lm.content)
+        replace = true
+      } else {
+        c = c.replaceAll('[[SYSTEM_PROMPT]]', '')
+      }
+      if (lm.role === 'user' && m.role === 'assistant' && c.includes('[[USER_PROMPT]]')) {
+        c = c.replaceAll('[[USER_PROMPT]]', lm.content)
+        replace = true
+      } else {
+        c = c.replaceAll('[[USER_PROMPT]]', '')
+      }
+    }
+    // Clean up merge fields on last
+    if (!rMessages[i + 1]) {
+      c = c.replaceAll('[[USER_PROMPT]]', '').replaceAll('[[SYSTEM_PROMPT]]', '')
+    }
+    const result = {
+      role: m.role,
+      content: c.trim()
+    } as Message
+    if (replace) {
+      a[a.length - 1] = result
+    } else {
+      a.push(result)
+    }
+    return a
+  }, [] as Message[])
+  const leadPrompt = (leadPromptSequence && ((inputArray[inputArray.length - 1] || {}) as Message).role !== 'assistant') ? deliminator + leadPromptSequence : ''
+  const fullPromptInput = getStartSequence(chat) + inputArray.map(m => m.content).join(deliminator) + leadPrompt
+
   let maxLen = Math.min(opts.maxTokens || chatSettings.max_tokens || maxTokens, maxTokens)
-  const promptTokenCount = chatResponse.getPromptTokenCount()
+  const promptTokenCount = countTokens(model, fullPromptInput)
   if (promptTokenCount > maxLen) {
     maxLen = Math.min(maxLen + promptTokenCount, maxTokens)
   }
+  // update with real count
+  chatResponse.setPromptTokenCount(promptTokenCount)
+  // set up the request
   chatResponse.onFinish(() => {
     const message = chatResponse.getMessages()[0]
     if (message) {
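The hunk above swaps chatResponse.getPromptTokenCount() for countTokens(model, fullPromptInput), so the prompt length is measured from the exact string that will be sent, before the request goes out, and the completion budget is widened if the prompt alone already exceeds it. A minimal sketch of that clamp logic with made-up numbers; countTokensApprox is a stand-in for illustration, not the real countTokens from Models.svelte:

  // Rough stand-in token counter (~4 characters per token); the real
  // countTokens uses the model's tokenizer.
  const countTokensApprox = (text: string): number => Math.ceil(text.length / 4)

  const clampMaxLen = (requested: number, modelMax: number, promptTokens: number): number => {
    // Start from the requested completion budget, capped at the model max.
    let maxLen = Math.min(requested, modelMax)
    // If the prompt alone already exceeds that budget, grow the budget,
    // but never past what the model context can hold.
    if (promptTokens > maxLen) {
      maxLen = Math.min(maxLen + promptTokens, modelMax)
    }
    return maxLen
  }

  // Example: 2048-token model, 1000 tokens requested, ~1500-token prompt
  // -> maxLen = min(1000 + 1500, 2048) = 2048.
  console.log(clampMaxLen(1000, 2048, countTokensApprox('x'.repeat(6000))))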
@@ -69,70 +136,9 @@ export const chatRequest = async (
     console.error(err)
     throw err
   }
-  // Enforce strict order of messages
-  const fMessages = (request.messages || [] as Message[])
-  const rMessages = fMessages.reduce((a, m, i) => {
-    a.push(m)
-    const nm = fMessages[i + 1]
-    if (m.role === 'system' && (!nm || nm.role !== 'user')) {
-      const nc = {
-        role: 'user',
-        content: ''
-      } as Message
-      a.push(nc)
-    }
-    return a
-  },
-  [] as Message[])
-  // make sure top_p and temperature are set the way we need
-  let temperature = request.temperature
-  if (temperature === undefined || isNaN(temperature as any)) temperature = 1
-  if (!temperature || temperature <= 0) temperature = 0.01
-  let topP = request.top_p
-  if (topP === undefined || isNaN(topP as any)) topP = 1
-  if (!topP || topP <= 0) topP = 0.01
-  // build the message array
-  const buildMessage = (m: Message): string => {
-    return getRoleTag(m.role, model, chat) + m.content + getRoleEnd(m.role, model, chat)
-  }
-  const inputArray = rMessages.reduce((a, m, i) => {
-    let c = buildMessage(m)
-    let replace = false
-    const lm = a[a.length - 1]
-    // Merge content if needed
-    if (lm) {
-      if (lm.role === 'system' && m.role === 'user' && c.includes('[[SYSTEM_PROMPT]]')) {
-        c = c.replaceAll('[[SYSTEM_PROMPT]]', lm.content)
-        replace = true
-      } else {
-        c = c.replaceAll('[[SYSTEM_PROMPT]]', '')
-      }
-      if (lm.role === 'user' && m.role === 'assistant' && c.includes('[[USER_PROMPT]]')) {
-        c = c.replaceAll('[[USER_PROMPT]]', lm.content)
-        replace = true
-      } else {
-        c = c.replaceAll('[[USER_PROMPT]]', '')
-      }
-    }
-    // Clean up merge fields on last
-    if (!rMessages[i + 1]) {
-      c = c.replaceAll('[[USER_PROMPT]]', '').replaceAll('[[SYSTEM_PROMPT]]', '')
-    }
-    const result = {
-      role: m.role,
-      content: c.trim()
-    } as Message
-    if (replace) {
-      a[a.length - 1] = result
-    } else {
-      a.push(result)
-    }
-    return a
-  }, [] as Message[])
-  const leadPrompt = (leadPromptSequence && ((inputArray[inputArray.length - 1] || {}) as Message).role !== 'assistant') ? deliminator + leadPromptSequence : ''
   const petalsRequest = {
     type: 'generate',
-    inputs: getStartSequence(chat) + inputArray.map(m => m.content).join(deliminator) + leadPrompt,
+    inputs: fullPromptInput,
     max_new_tokens: 1, // wait for up to 1 tokens before displaying
     stop_sequence: stopSequence,
     do_sample: 1, // enable top p and the like
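The block removed in this hunk is the same prompt-assembly code that the earlier hunk moved up the file, so the finished prompt string exists before its tokens are counted; inputs then simply reuses fullPromptInput. For context, a simplified standalone sketch of what that assembly does; it only handles the [[SYSTEM_PROMPT]] merge field, and the role tags and delimiter here are illustrative placeholders, not what getRoleTag/getDeliminator actually return:

  type Role = 'system' | 'user' | 'assistant'
  interface Msg { role: Role; content: string }

  // Illustrative stand-ins; the real values come from the model
  // definition in Models.svelte.
  const roleTag: Record<Role, string> = { system: '<|system|>', user: '<|user|>', assistant: '<|assistant|>' }
  const deliminator = '\n'

  // Flatten the chat into one prompt string, folding a preceding system
  // message into the user turn that carries the [[SYSTEM_PROMPT]] field.
  const buildPromptInput = (messages: Msg[]): string => {
    const parts: Msg[] = []
    for (const m of messages) {
      let c = roleTag[m.role] + m.content
      const last = parts[parts.length - 1]
      let replace = false
      if (last && last.role === 'system' && m.role === 'user' && c.includes('[[SYSTEM_PROMPT]]')) {
        c = c.replaceAll('[[SYSTEM_PROMPT]]', last.content)
        replace = true
      } else {
        c = c.replaceAll('[[SYSTEM_PROMPT]]', '')
      }
      const result: Msg = { role: m.role, content: c.trim() }
      if (replace) parts[parts.length - 1] = result
      else parts.push(result)
    }
    return parts.map(p => p.content).join(deliminator)
  }

  // Prints '<|user|><|system|>Be terse. Hello'
  console.log(buildPromptInput([
    { role: 'system', content: 'Be terse.' },
    { role: 'user', content: '[[SYSTEM_PROMPT]] Hello' }
  ]))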
@@ -141,6 +147,8 @@ export const chatRequest = async (
     // repitition_penalty: chatSettings.repititionPenalty
   } as any
   if (stopSequencesC.length) petalsRequest.extra_stop_sequences = stopSequencesC
+  // Update token count
+  chatResponse.setPromptTokenCount(promptTokenCount)
   ws.send(JSON.stringify(petalsRequest))
   ws.onmessage = event => {
     // Remove updating indicator
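Taken together, the hunks appear to make the prompt count a locally derived value: it is computed from the exact string sent as inputs rather than read back from the response object, and it is recorded with setPromptTokenCount both after the length check and again here just before ws.send fires, so the stored count matches the request that actually goes over the socket.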