Update token count

This commit is contained in:
Webifi 2023-08-16 01:51:49 -05:00
parent 86f427f62f
commit f4d9774423
1 changed file with 72 additions and 64 deletions

View File

@ -1,7 +1,7 @@
<script context="module" lang="ts">
import { ChatCompletionResponse } from '../../ChatCompletionResponse.svelte'
import { ChatRequest } from '../../ChatRequest.svelte'
import { getDeliminator, getEndpoint, getLeadPrompt, getModelDetail, getRoleEnd, getRoleTag, getStartSequence, getStopSequence } from '../../Models.svelte'
import { countTokens, getDeliminator, getEndpoint, getLeadPrompt, getModelDetail, getRoleEnd, getRoleTag, getStartSequence, getStopSequence } from '../../Models.svelte'
import type { ChatCompletionOpts, Message, Request } from '../../Types.svelte'
import { getModelMaxTokens } from '../../Stats.svelte'
import { updateMessages } from '../../Storage.svelte'
@ -35,40 +35,7 @@ export const chatRequest = async (
stopSequences = stopSequences.sort((a, b) => b.length - a.length)
const stopSequencesC = stopSequences.filter(s => s !== stopSequence)
const maxTokens = getModelMaxTokens(model)
let maxLen = Math.min(opts.maxTokens || chatSettings.max_tokens || maxTokens, maxTokens)
const promptTokenCount = chatResponse.getPromptTokenCount()
if (promptTokenCount > maxLen) {
maxLen = Math.min(maxLen + promptTokenCount, maxTokens)
}
chatResponse.onFinish(() => {
const message = chatResponse.getMessages()[0]
if (message) {
for (let i = 0, l = stopSequences.length; i < l; i++) {
const ss = stopSequences[i].trim()
if (message.content.trim().endsWith(ss)) {
message.content = message.content.trim().slice(0, message.content.trim().length - ss.length)
updateMessages(chat.id)
}
}
}
chatRequest.updating = false
chatRequest.updatingMessage = ''
ws.close()
})
ws.onopen = () => {
ws.send(JSON.stringify({
type: 'open_inference_session',
model,
max_length: maxLen
}))
ws.onmessage = event => {
const response = JSON.parse(event.data)
if (!response.ok) {
const err = new Error('Error opening socket: ' + response.traceback)
chatResponse.updateFromError(err.message)
console.error(err)
throw err
}
// Enforce strict order of messages
const fMessages = (request.messages || [] as Message[])
const rMessages = fMessages.reduce((a, m, i) => {
@ -130,9 +97,48 @@ export const chatRequest = async (
return a
}, [] as Message[])
const leadPrompt = (leadPromptSequence && ((inputArray[inputArray.length - 1] || {}) as Message).role !== 'assistant') ? deliminator + leadPromptSequence : ''
const fullPromptInput = getStartSequence(chat) + inputArray.map(m => m.content).join(deliminator) + leadPrompt
let maxLen = Math.min(opts.maxTokens || chatSettings.max_tokens || maxTokens, maxTokens)
const promptTokenCount = countTokens(model, fullPromptInput)
if (promptTokenCount > maxLen) {
maxLen = Math.min(maxLen + promptTokenCount, maxTokens)
}
// update with real count
chatResponse.setPromptTokenCount(promptTokenCount)
// set up the request
chatResponse.onFinish(() => {
const message = chatResponse.getMessages()[0]
if (message) {
for (let i = 0, l = stopSequences.length; i < l; i++) {
const ss = stopSequences[i].trim()
if (message.content.trim().endsWith(ss)) {
message.content = message.content.trim().slice(0, message.content.trim().length - ss.length)
updateMessages(chat.id)
}
}
}
chatRequest.updating = false
chatRequest.updatingMessage = ''
ws.close()
})
ws.onopen = () => {
ws.send(JSON.stringify({
type: 'open_inference_session',
model,
max_length: maxLen
}))
ws.onmessage = event => {
const response = JSON.parse(event.data)
if (!response.ok) {
const err = new Error('Error opening socket: ' + response.traceback)
chatResponse.updateFromError(err.message)
console.error(err)
throw err
}
const petalsRequest = {
type: 'generate',
inputs: getStartSequence(chat) + inputArray.map(m => m.content).join(deliminator) + leadPrompt,
inputs: fullPromptInput,
max_new_tokens: 1, // wait for up to 1 tokens before displaying
stop_sequence: stopSequence,
do_sample: 1, // enable top p and the like
@ -141,6 +147,8 @@ export const chatRequest = async (
// repitition_penalty: chatSettings.repititionPenalty
} as any
if (stopSequencesC.length) petalsRequest.extra_stop_sequences = stopSequencesC
// Update token count
chatResponse.setPromptTokenCount(promptTokenCount)
ws.send(JSON.stringify(petalsRequest))
ws.onmessage = event => {
// Remove updating indicator