update token count
commit f4d9774423 (parent 86f427f62f)
@@ -1,7 +1,7 @@
 <script context="module" lang="ts">
   import { ChatCompletionResponse } from '../../ChatCompletionResponse.svelte'
   import { ChatRequest } from '../../ChatRequest.svelte'
-  import { getDeliminator, getEndpoint, getLeadPrompt, getModelDetail, getRoleEnd, getRoleTag, getStartSequence, getStopSequence } from '../../Models.svelte'
+  import { countTokens, getDeliminator, getEndpoint, getLeadPrompt, getModelDetail, getRoleEnd, getRoleTag, getStartSequence, getStopSequence } from '../../Models.svelte'
   import type { ChatCompletionOpts, Message, Request } from '../../Types.svelte'
   import { getModelMaxTokens } from '../../Stats.svelte'
   import { updateMessages } from '../../Storage.svelte'
@@ -35,40 +35,7 @@ export const chatRequest = async (
     stopSequences = stopSequences.sort((a, b) => b.length - a.length)
     const stopSequencesC = stopSequences.filter(s => s !== stopSequence)
     const maxTokens = getModelMaxTokens(model)
-    let maxLen = Math.min(opts.maxTokens || chatSettings.max_tokens || maxTokens, maxTokens)
-    const promptTokenCount = chatResponse.getPromptTokenCount()
-    if (promptTokenCount > maxLen) {
-      maxLen = Math.min(maxLen + promptTokenCount, maxTokens)
-    }
-    chatResponse.onFinish(() => {
-      const message = chatResponse.getMessages()[0]
-      if (message) {
-        for (let i = 0, l = stopSequences.length; i < l; i++) {
-          const ss = stopSequences[i].trim()
-          if (message.content.trim().endsWith(ss)) {
-            message.content = message.content.trim().slice(0, message.content.trim().length - ss.length)
-            updateMessages(chat.id)
-          }
-        }
-      }
-      chatRequest.updating = false
-      chatRequest.updatingMessage = ''
-      ws.close()
-    })
-    ws.onopen = () => {
-      ws.send(JSON.stringify({
-        type: 'open_inference_session',
-        model,
-        max_length: maxLen
-      }))
-      ws.onmessage = event => {
-        const response = JSON.parse(event.data)
-        if (!response.ok) {
-          const err = new Error('Error opening socket: ' + response.traceback)
-          chatResponse.updateFromError(err.message)
-          console.error(err)
-          throw err
-        }
+
     // Enforce strict order of messages
     const fMessages = (request.messages || [] as Message[])
     const rMessages = fMessages.reduce((a, m, i) => {
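Note: the onFinish handler removed by this hunk is re-added unchanged further down; it strips a stop sequence that the model echoes at the end of its reply before the message is saved. A minimal standalone sketch of that trimming step, assuming only that the stop sequences are pre-sorted longest-first as the code above does (trimStopSequences is a hypothetical name, not part of the commit):

// Sketch of the stop-sequence trimming done in the onFinish handler.
const trimStopSequences = (content: string, stopSequences: string[]): string => {
  let out = content.trim()
  for (const s of stopSequences) {
    const ss = s.trim()
    if (ss && out.endsWith(ss)) {
      // Drop the echoed stop sequence; keep the rest of the reply.
      out = out.slice(0, out.length - ss.length)
    }
  }
  return out
}

// e.g. trimStopSequences('Hello!###', ['###']) returns 'Hello!'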
@@ -130,9 +97,48 @@ export const chatRequest = async (
       return a
     }, [] as Message[])
     const leadPrompt = (leadPromptSequence && ((inputArray[inputArray.length - 1] || {}) as Message).role !== 'assistant') ? deliminator + leadPromptSequence : ''
+    const fullPromptInput = getStartSequence(chat) + inputArray.map(m => m.content).join(deliminator) + leadPrompt
+
+    let maxLen = Math.min(opts.maxTokens || chatSettings.max_tokens || maxTokens, maxTokens)
+    const promptTokenCount = countTokens(model, fullPromptInput)
+    if (promptTokenCount > maxLen) {
+      maxLen = Math.min(maxLen + promptTokenCount, maxTokens)
+    }
+    // update with real count
+    chatResponse.setPromptTokenCount(promptTokenCount)
+    // set up the request
+    chatResponse.onFinish(() => {
+      const message = chatResponse.getMessages()[0]
+      if (message) {
+        for (let i = 0, l = stopSequences.length; i < l; i++) {
+          const ss = stopSequences[i].trim()
+          if (message.content.trim().endsWith(ss)) {
+            message.content = message.content.trim().slice(0, message.content.trim().length - ss.length)
+            updateMessages(chat.id)
+          }
+        }
+      }
+      chatRequest.updating = false
+      chatRequest.updatingMessage = ''
+      ws.close()
+    })
+    ws.onopen = () => {
+      ws.send(JSON.stringify({
+        type: 'open_inference_session',
+        model,
+        max_length: maxLen
+      }))
+      ws.onmessage = event => {
+        const response = JSON.parse(event.data)
+        if (!response.ok) {
+          const err = new Error('Error opening socket: ' + response.traceback)
+          chatResponse.updateFromError(err.message)
+          console.error(err)
+          throw err
+        }
         const petalsRequest = {
           type: 'generate',
-          inputs: getStartSequence(chat) + inputArray.map(m => m.content).join(deliminator) + leadPrompt,
+          inputs: fullPromptInput,
           max_new_tokens: 1, // wait for up to 1 tokens before displaying
           stop_sequence: stopSequence,
           do_sample: 1, // enable top p and the like
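Note: the clamp introduced above widens the session budget when the prompt alone exceeds the requested maximum, but never past the model's hard limit. The same arithmetic as a standalone sketch (computeMaxLength is a hypothetical name; the numbers in the examples are illustrative):

// Sketch of the max_length clamp, with `requested` standing in for
// opts.maxTokens || chatSettings.max_tokens || maxTokens.
const computeMaxLength = (requested: number, promptTokenCount: number, modelMax: number): number => {
  let maxLen = Math.min(requested, modelMax)
  if (promptTokenCount > maxLen) {
    // The prompt alone is over budget: grow by the prompt size,
    // but never beyond the model's context limit.
    maxLen = Math.min(maxLen + promptTokenCount, modelMax)
  }
  return maxLen
}

// e.g. computeMaxLength(1024, 1500, 2048) returns 2048,
// while computeMaxLength(1024, 500, 2048) stays at 1024.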
@@ -141,6 +147,8 @@ export const chatRequest = async (
           // repitition_penalty: chatSettings.repititionPenalty
         } as any
         if (stopSequencesC.length) petalsRequest.extra_stop_sequences = stopSequencesC
+        // Update token count
+        chatResponse.setPromptTokenCount(promptTokenCount)
         ws.send(JSON.stringify(petalsRequest))
         ws.onmessage = event => {
           // Remove updating indicator
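Note: taken together, the wiring above is a two-step websocket exchange: open_inference_session reserves a session sized by the clamped max_length, then generate sends the prompt with its stop sequences. A condensed sketch of that flow (the url parameter and onData callback are illustrative; only the message fields come from the diff):

// Sketch of the two-step Petals websocket exchange.
const runInference = (url: string, model: string, maxLen: number,
  inputs: string, stopSequence: string, stopSequencesC: string[],
  onData: (data: string) => void) => {
  const ws = new WebSocket(url)
  ws.onopen = () => {
    // Step 1: open a session sized to the clamped token budget.
    ws.send(JSON.stringify({ type: 'open_inference_session', model, max_length: maxLen }))
    ws.onmessage = event => {
      const response = JSON.parse(event.data)
      if (!response.ok) throw new Error('Error opening socket: ' + response.traceback)
      // Step 2: request generation with the prompt and stop sequences.
      ws.send(JSON.stringify({
        type: 'generate',
        inputs,
        max_new_tokens: 1,
        stop_sequence: stopSequence,
        extra_stop_sequences: stopSequencesC,
        do_sample: 1
      }))
      // Later messages stream generated text; the diff swaps in a new
      // ws.onmessage handler for that phase, represented by a callback here.
      ws.onmessage = ev => onData(ev.data as string)
    }
  }
}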