update token count

commit f4d9774423 (parent 86f427f62f)
@@ -1,7 +1,7 @@
 <script context="module" lang="ts">
   import { ChatCompletionResponse } from '../../ChatCompletionResponse.svelte'
   import { ChatRequest } from '../../ChatRequest.svelte'
-  import { getDeliminator, getEndpoint, getLeadPrompt, getModelDetail, getRoleEnd, getRoleTag, getStartSequence, getStopSequence } from '../../Models.svelte'
+  import { countTokens, getDeliminator, getEndpoint, getLeadPrompt, getModelDetail, getRoleEnd, getRoleTag, getStartSequence, getStopSequence } from '../../Models.svelte'
   import type { ChatCompletionOpts, Message, Request } from '../../Types.svelte'
   import { getModelMaxTokens } from '../../Stats.svelte'
   import { updateMessages } from '../../Storage.svelte'
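The only import change pulls in countTokens from Models.svelte, so this file can measure the assembled prompt itself instead of relying on a previously stored estimate. The helper's implementation is not part of this diff; the sketch below is a hypothetical stand-in, assuming a Llama-style tokenizer (Petals serves Llama-family models) via llama-tokenizer-js.

```ts
// Hypothetical stand-in for the countTokens helper imported above; the
// real implementation lives in Models.svelte and is not shown in this diff.
import llamaTokenizer from 'llama-tokenizer-js'

export const countTokens = (model: string, value: string): number => {
  // A Llama tokenizer approximates the count the Petals server will see;
  // a fuller version would presumably pick a tokenizer per model.
  return llamaTokenizer.encode(value).length
}
```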
@@ -35,40 +35,7 @@ export const chatRequest = async (
     stopSequences = stopSequences.sort((a, b) => b.length - a.length)
     const stopSequencesC = stopSequences.filter(s => s !== stopSequence)
     const maxTokens = getModelMaxTokens(model)
-    let maxLen = Math.min(opts.maxTokens || chatSettings.max_tokens || maxTokens, maxTokens)
-    const promptTokenCount = chatResponse.getPromptTokenCount()
-    if (promptTokenCount > maxLen) {
-      maxLen = Math.min(maxLen + promptTokenCount, maxTokens)
-    }
-    chatResponse.onFinish(() => {
-      const message = chatResponse.getMessages()[0]
-      if (message) {
-        for (let i = 0, l = stopSequences.length; i < l; i++) {
-          const ss = stopSequences[i].trim()
-          if (message.content.trim().endsWith(ss)) {
-            message.content = message.content.trim().slice(0, message.content.trim().length - ss.length)
-            updateMessages(chat.id)
-          }
-        }
-      }
-      chatRequest.updating = false
-      chatRequest.updatingMessage = ''
-      ws.close()
-    })
-    ws.onopen = () => {
-      ws.send(JSON.stringify({
-        type: 'open_inference_session',
-        model,
-        max_length: maxLen
-      }))
-      ws.onmessage = event => {
-        const response = JSON.parse(event.data)
-        if (!response.ok) {
-          const err = new Error('Error opening socket: ' + response.traceback)
-          chatResponse.updateFromError(err.message)
-          console.error(err)
-          throw err
-        }
     // Enforce strict order of messages
     const fMessages = (request.messages || [] as Message[])
     const rMessages = fMessages.reduce((a, m, i) => {
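The block removed here is not deleted outright: it reappears in the next hunk, after the prompt has been assembled, so max_length can be sized from the real prompt. What it performs is Petals' two-step WebSocket protocol, sketched standalone below; the endpoint URL and model name are illustrative assumptions, not values taken from this file.

```ts
// Standalone sketch of the handshake the code above drives. The URL and
// model name are assumptions for illustration only.
const ws = new WebSocket('wss://chat.petals.dev/api/v2/generate')
ws.onopen = () => {
  // Step 1: open a session sized for prompt + completion tokens.
  ws.send(JSON.stringify({ type: 'open_inference_session', model: 'meta-llama/Llama-2-70b-chat-hf', max_length: 1024 }))
  ws.onmessage = event => {
    const response = JSON.parse(event.data)
    if (!response.ok) throw new Error('Error opening session: ' + response.traceback)
    // Step 2: request generation against the opened session.
    ws.send(JSON.stringify({ type: 'generate', inputs: 'Hello', max_new_tokens: 1, do_sample: 1, stop_sequence: '</s>' }))
  }
}
```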
@@ -130,9 +97,48 @@ export const chatRequest = async (
       return a
     }, [] as Message[])
     const leadPrompt = (leadPromptSequence && ((inputArray[inputArray.length - 1] || {}) as Message).role !== 'assistant') ? deliminator + leadPromptSequence : ''
+    const fullPromptInput = getStartSequence(chat) + inputArray.map(m => m.content).join(deliminator) + leadPrompt
+
+    let maxLen = Math.min(opts.maxTokens || chatSettings.max_tokens || maxTokens, maxTokens)
+    const promptTokenCount = countTokens(model, fullPromptInput)
+    if (promptTokenCount > maxLen) {
+      maxLen = Math.min(maxLen + promptTokenCount, maxTokens)
+    }
+    // update with real count
+    chatResponse.setPromptTokenCount(promptTokenCount)
+    // set up the request
+    chatResponse.onFinish(() => {
+      const message = chatResponse.getMessages()[0]
+      if (message) {
+        for (let i = 0, l = stopSequences.length; i < l; i++) {
+          const ss = stopSequences[i].trim()
+          if (message.content.trim().endsWith(ss)) {
+            message.content = message.content.trim().slice(0, message.content.trim().length - ss.length)
+            updateMessages(chat.id)
+          }
+        }
+      }
+      chatRequest.updating = false
+      chatRequest.updatingMessage = ''
+      ws.close()
+    })
+    ws.onopen = () => {
+      ws.send(JSON.stringify({
+        type: 'open_inference_session',
+        model,
+        max_length: maxLen
+      }))
+      ws.onmessage = event => {
+        const response = JSON.parse(event.data)
+        if (!response.ok) {
+          const err = new Error('Error opening socket: ' + response.traceback)
+          chatResponse.updateFromError(err.message)
+          console.error(err)
+          throw err
+        }
         const petalsRequest = {
           type: 'generate',
-          inputs: getStartSequence(chat) + inputArray.map(m => m.content).join(deliminator) + leadPrompt,
+          inputs: fullPromptInput,
           max_new_tokens: 1, // wait for up to 1 tokens before displaying
           stop_sequence: stopSequence,
           do_sample: 1, // enable top p and the like
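The re-added sizing logic now runs after fullPromptInput exists, so promptTokenCount is a real measurement rather than a stored estimate. When the counted prompt alone exceeds the requested budget, the session length grows by the prompt size, capped at the model limit. A worked example with made-up numbers:

```ts
// Worked example of the max_length sizing above; the numbers are illustrative.
const maxTokens = 2048                    // model context limit
let maxLen = Math.min(1000, maxTokens)    // caller asked for 1000 tokens
const promptTokenCount = 1500             // counted from fullPromptInput
if (promptTokenCount > maxLen) {
  // The prompt alone would exhaust the budget, so widen it to leave
  // room for a completion, never past the model limit.
  maxLen = Math.min(maxLen + promptTokenCount, maxTokens) // -> 2048
}
console.log(maxLen) // 2048
```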
@@ -141,6 +147,8 @@ export const chatRequest = async (
           // repitition_penalty: chatSettings.repititionPenalty
         } as any
         if (stopSequencesC.length) petalsRequest.extra_stop_sequences = stopSequencesC
+        // Update token count
+        chatResponse.setPromptTokenCount(promptTokenCount)
         ws.send(JSON.stringify(petalsRequest))
         ws.onmessage = event => {
           // Remove updating indicator
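Once the generate request is built and the prompt token count recorded on the response, the onFinish handler re-added above trims any trailing stop sequence from the final message; the sequences were sorted longest-first, so the longest match is tried first. The same rule extracted as a pure function, with a hypothetical name:

```ts
// The trailing-stop-sequence trim from onFinish, as a pure function for
// illustration; the function name is hypothetical.
const stripTrailingStop = (content: string, stopSequences: string[]): string => {
  let result = content.trim()
  for (const s of stopSequences) {
    const ss = s.trim()
    if (ss && result.endsWith(ss)) {
      result = result.slice(0, result.length - ss.length)
    }
  }
  return result
}

console.log(stripTrailingStop('Hello world</s>', ['</s>'])) // 'Hello world'
```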