Add holdSocket option for faster Petals chat

Webifi 2023-08-31 17:25:16 -05:00
parent e991dfd9b7
commit 2e4181bf7e
6 changed files with 216 additions and 106 deletions
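In short: instead of opening a fresh WebSocket (and with it a fresh Petals inference session) for every chat message, the client now caches the socket and the text already streamed through it in `chatRequest.providerData.petals`, and only sends the new part of the prompt when the cached state still matches. A minimal sketch of that reuse pattern, with illustrative names (`PetalsSession`, `sessions`, `send`, and `openWs` are not the commit's actual identifiers):

```ts
// Sketch only: cache one session per model, keyed by what the server has seen.
type PetalsSession = { ws: WebSocket, knownBuffer: string }
const sessions: Record<string, PetalsSession> = {}

async function send (model: string, lastPrompt: string, nextPrompt: string,
  openWs: () => Promise<WebSocket>): Promise<void> {
  let s = sessions[model]
  let inputs = nextPrompt
  // Re-use only if the socket is still open and the server-side context
  // matches the prompt we would otherwise have to replay from scratch.
  if (!s || s.ws.readyState !== WebSocket.OPEN || s.knownBuffer !== lastPrompt) {
    s = sessions[model] = { ws: await openWs(), knownBuffer: '' }
    inputs = lastPrompt + nextPrompt // fresh session: replay the whole history
  }
  s.ws.send(JSON.stringify({ type: 'generate', inputs }))
  s.knownBuffer += inputs
}
```

On reuse, only the latest turn crosses the wire; the trade-off, as the new setting's tooltip warns, is that message delimitation can get mangled if the held session drifts out of sync.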

View File

@@ -21,6 +21,7 @@ export class ChatRequest {
  updating: boolean|number = false
  updatingMessage: string = ''
  controller:AbortController
+  providerData: Record<string, any> = {}

  setChat (chat: Chat) {
    this.chat = chat

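The new `providerData` field gives each provider a scratch area that survives across messages on the same `ChatRequest`. The Petals diff further down uses it exactly like this (lines taken from this commit, annotated here for clarity):

```ts
// One namespace per provider, created on first use:
const providerData = chatRequest.providerData.petals || {}
chatRequest.providerData.petals = providerData
// providerData.ws          - the WebSocket being held open, if any
// providerData.knownBuffer - everything already sent to / received from it
```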
View File

@@ -122,6 +122,7 @@ const defaults:ChatSettings = {
  systemMessageEnd: '',
  leadPrompt: '',
  repetitionPenalty: 1.1,
+  holdSocket: true,
  // useResponseAlteration: false,
  // responseAlterations: [],
  isDirty: false
@@ -451,6 +452,13 @@ const chatSettingsList: ChatSetting[] = [
    type: 'boolean',
    hide: hideModelSetting
  },
+  {
+    key: 'holdSocket',
+    name: 'Continue WebSocket',
+    title: 'Hold the WebSocket connection open and try to re-use it for each new chat message. Faster, but message delimitation could get mangled.',
+    type: 'boolean',
+    hide: hideModelSetting
+  },
  {
    key: 'temperature',
    name: 'Sampling Temperature',

View File

@@ -96,6 +96,7 @@ export type ChatSettings = {
  systemMessageStart: string;
  systemMessageEnd: string;
  repetitionPenalty: number;
+  holdSocket: boolean;
  isDirty?: boolean;
} & Request;

View File

@@ -154,4 +154,8 @@
    return value
  }

+  export const escapeRegex = (string: string): string => {
+    return string.replace(/[/\-\\^$*+?.()|[\]{}]/g, '\\$&')
+  }
+
</script>
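`escapeRegex` backslash-escapes regex metacharacters so an arbitrary stop sequence can be spliced into a dynamically built pattern. The Petals diff below uses it to normalize buffers before comparing them, roughly like so (the sample strings are illustrative):

```ts
// Strip sentinel tokens, whitespace, and the stop sequence before comparing,
// so cosmetic differences don't force a brand-new inference session:
const stopSequence = '</s>'
const rgxp = new RegExp('(<s>|</s>|\\s|' + escapeRegex(stopSequence) + ')', 'g')
'Hello </s> world'.replace(rgxp, '') // => 'Helloworld'
```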

View File

@@ -20,7 +20,8 @@ const hiddenSettings = {
  assistantMessageEnd: true,
  systemMessageStart: true,
  systemMessageEnd: true,
-  repetitionPenalty: true
+  repetitionPenalty: true,
+  holdSocket: true
  // leadPrompt: true
}

View File

@@ -5,6 +5,29 @@
  import type { ChatCompletionOpts, Message, Request } from '../../Types.svelte'
  import { getModelMaxTokens } from '../../Stats.svelte'
  import { updateMessages } from '../../Storage.svelte'
+  import { escapeRegex } from '../../Util.svelte'
+
+  const levenshteinDistance = (str1 = '', str2 = '') => {
+    const track = Array(str2.length + 1).fill(null).map(() =>
+      Array(str1.length + 1).fill(null))
+    for (let i = 0; i <= str1.length; i += 1) {
+      track[0][i] = i
+    }
+    for (let j = 0; j <= str2.length; j += 1) {
+      track[j][0] = j
+    }
+    for (let j = 1; j <= str2.length; j += 1) {
+      for (let i = 1; i <= str1.length; i += 1) {
+        const indicator = str1[i - 1] === str2[j - 1] ? 0 : 1
+        track[j][i] = Math.min(
+          track[j][i - 1] + 1, // deletion
+          track[j - 1][i] + 1, // insertion
+          track[j - 1][i - 1] + indicator // substitution
+        )
+      }
+    }
+    return track[str2.length][str1.length]
+  }

  export const chatRequest = async (
    request: Request,
@@ -16,8 +39,10 @@ export const chatRequest = async (
  const chatSettings = chat.settings
  const model = chatRequest.getModel()
  const modelDetail = getModelDetail(model)
-  const ws = new WebSocket(getEndpoint(model))
  const signal = chatRequest.controller.signal
+  const providerData = chatRequest.providerData.petals || {}
+  chatRequest.providerData.petals = providerData
+  let ws: WebSocket = providerData.ws
  const abortListener = (e:Event) => {
    chatRequest.updating = false
    chatRequest.updatingMessage = ''
@@ -26,9 +51,17 @@ export const chatRequest = async (
    ws.close()
  }
  signal.addEventListener('abort', abortListener)
+  const startSequence = getStartSequence(chat)
  let stopSequences = [...new Set(getStopSequence(chat).split(',').filter(s => s.trim()).concat((modelDetail.stop || ['###', '</s>']).slice()))]
-  const stopSequence = '</s>'
+  let stopSequence = stopSequences[0] || '###'
+  if (startSequence.length) {
+    const sld = stopSequences.slice()
+      .filter(s => s === '###' || s === '</s>' || countTokens(model, s) === 1)
+      .sort((a, b) => levenshteinDistance(a, startSequence) - levenshteinDistance(b, startSequence))
+    stopSequence = sld[0] || stopSequence
+  }
  stopSequences.push(stopSequence)
  const delimiter = getDelimiter(chat)
  const leadPromptSequence = getLeadPrompt(chat)
  if (delimiter) stopSequences.unshift(delimiter.trim())
@@ -62,13 +95,8 @@ export const chatRequest = async (
  const buildMessage = (m: Message): string => {
    return getRoleTag(m.role, model, chat) + m.content + getRoleEnd(m.role, model, chat)
  }
-  const lastMessage = rMessages[rMessages.length - 1]
-  let doLead = true
-  if (lastMessage && lastMessage.role === 'assistant') {
-    lastMessage.content = leadPromptSequence + lastMessage.content
-    doLead = false
-  }
-  const inputArray = rMessages.reduce((a, m, i) => {
+  const buildInputArray = (a) => {
+    return a.reduce((a, m, i) => {
      let c = buildMessage(m)
      let replace = false
      const lm = a[a.length - 1]
@@ -102,16 +130,20 @@ export const chatRequest = async (
      }
      return a
    }, [] as Message[])
-  const leadPrompt = (leadPromptSequence && doLead) ? delimiter + leadPromptSequence : ''
-  const fullPromptInput = getStartSequence(chat) + inputArray.map(m => m.content).join(delimiter) + leadPrompt
-  let maxLen = Math.min(opts.maxTokens || chatSettings.max_tokens || maxTokens, maxTokens)
-  const promptTokenCount = countTokens(model, fullPromptInput)
-  if (promptTokenCount > maxLen) {
-    maxLen = Math.min(maxLen + promptTokenCount, maxTokens)
-  }
-  // update with real count
-  chatResponse.setPromptTokenCount(promptTokenCount)
+  }
+  const lastMessage = rMessages[rMessages.length - 1]
+  let doLead = true
+  if (lastMessage && lastMessage.role === 'assistant') {
+    lastMessage.content = leadPromptSequence + lastMessage.content
+    doLead = false
+  }
+  // const inputArray = buildInputArray(rMessages).map(m => m.content)
+  const lInputArray = buildInputArray(rMessages.slice(0, -1)).map(m => m.content)
+  const nInputArray = buildInputArray(rMessages.slice(-1)).map(m => m.content)
+  const leadPrompt = (leadPromptSequence && doLead) ? delimiter + leadPromptSequence : ''
+  const lastPrompt = startSequence + lInputArray.join(delimiter)
+  const nextPrompt = nInputArray.slice(-1).join('') + leadPrompt

  // set up the request
  chatResponse.onFinish(() => {
    const message = chatResponse.getMessages()[0]
@@ -124,25 +156,93 @@ export const chatRequest = async (
        }
      }
    }
-    ws.close()
+    !chatSettings.holdSocket && ws.close()
  })
-  ws.onopen = () => {
-    ws.send(JSON.stringify({
-      type: 'open_inference_session',
-      model,
-      max_length: maxLen
-    }))
-    ws.onmessage = event => {
+
+  let maxLen = Math.min(opts.maxTokens || chatSettings.max_tokens || maxTokens, maxTokens)
+
+  let inputPrompt = startSequence
+
+  const getNewWs = ():Promise<WebSocket> => new Promise<WebSocket>((resolve, reject) => {
+    // console.warn('requesting new ws')
+    const nws = new WebSocket(getEndpoint(model))
+    let opened = false
+    let done = false
+    nws.onmessage = event => {
+      if (done) return
+      done = true
      const response = JSON.parse(event.data)
      if (!response.ok) {
        const err = new Error('Error opening socket: ' + response.traceback)
        chatResponse.updateFromError(err.message)
+        console.error(err)
+        reject(err)
+      }
+      nws.onerror = err => {
        console.error(err)
        throw err
      }
+      // console.warn('got new ws')
+      inputPrompt = lastPrompt
+      providerData.knownBuffer = ''
+      providerData.ws = nws
+      resolve(nws)
+    }
+    nws.onclose = () => {
+      chatResponse.updateFromClose()
+    }
+    nws.onerror = err => {
+      if (done) return
+      done = true
+      console.error(err)
+      reject(err)
+    }
+    nws.onopen = () => {
+      if (opened) return
+      opened = true
+      const promptTokenCount = countTokens(model, lastPrompt + delimiter + nextPrompt)
+      if (promptTokenCount > maxLen) {
+        maxLen = Math.min(maxLen + promptTokenCount, maxTokens)
+      }
+      // update with real count
+      chatResponse.setPromptTokenCount(promptTokenCount)
+      nws.send(JSON.stringify({
+        type: 'open_inference_session',
+        model,
+        max_length: chatSettings.holdSocket ? maxTokens : maxLen
+      }))
+    }
+  })
+
+  const wsOpen = (ws && ws.readyState !== WebSocket.CLOSED)
+
+  if (!chatSettings.holdSocket || wsOpen) {
+    const rgxp = new RegExp('(<s>|</s>|\\s|' + escapeRegex(stopSequence) + ')', 'g')
+    const kb = (providerData.knownBuffer || '').replace(rgxp, '')
+    const lp = lastPrompt.replace(rgxp, '')
+    const lm = kb === lp
+    if (!lm || countTokens(model, providerData.knownBuffer + inputPrompt) >= maxTokens) {
+      wsOpen && ws.close()
+      ws = await getNewWs()
+    }
+  }
+
+  if (!ws || ws.readyState === WebSocket.CLOSED) {
+    ws = await getNewWs()
+  }
+
+  inputPrompt += delimiter + nextPrompt
+  providerData.knownBuffer += inputPrompt
+
+  // console.log(
+  //   '\n\n*** inputPrompt: ***\n\n',
+  //   inputPrompt
+  // )
+
  const petalsRequest = {
    type: 'generate',
-    inputs: fullPromptInput,
+    inputs: inputPrompt,
    max_new_tokens: 1, // wait for up to 1 token before displaying
    stop_sequence: stopSequence,
    do_sample: 1, // enable top p and the like
@@ -152,8 +252,7 @@ export const chatRequest = async (
  } as any
  if (stopSequencesC.length) petalsRequest.extra_stop_sequences = stopSequencesC
  // Update token count
-  chatResponse.setPromptTokenCount(promptTokenCount)
-  ws.send(JSON.stringify(petalsRequest))
+  chatResponse.setPromptTokenCount(countTokens(model, providerData.knownBuffer))
  ws.onmessage = event => {
    // Remove updating indicator
    chatRequest.updating = 1 // hide indicator, but still signal we're updating
@@ -168,6 +267,7 @@ export const chatRequest = async (
      chatResponse.updateFromError(err.message)
      throw err
    }
+    providerData.knownBuffer += response.outputs
    chatResponse.updateFromAsyncResponse(
      {
        model,
@@ -195,21 +295,16 @@ export const chatRequest = async (
            response.stop = true
            updateMessages(chat.id)
            chatResponse.finish()
+            if (ss !== stopSequence) {
+              providerData.knownBuffer += stopSequence
+            }
            ws.close()
          }
        }
      }
    }
-    }
-    ws.onclose = () => {
-      chatResponse.updateFromClose()
-    }
-    ws.onerror = err => {
-      console.error(err)
-      throw err
-    }
-  }
+  ws.send(JSON.stringify(petalsRequest))
  return chatResponse
}
</script>
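A closing note on the Levenshtein-based selection above: when the model defines a start sequence, the client now prefers the known-safe or single-token stop sequence with the smallest edit distance to that start sequence rather than hard-coding '</s>', presumably because a stop marker resembling the start marker is likeliest to be the model's native turn delimiter. With hypothetical candidates:

```ts
// Hypothetical values, just to show the ordering:
const startSequence = '<s>'
const candidates = ['###', '</s>']
candidates.sort((a, b) =>
  levenshteinDistance(a, startSequence) - levenshteinDistance(b, startSequence))
// levenshteinDistance('###', '<s>') === 3 but levenshteinDistance('</s>', '<s>') === 1,
// so '</s>' sorts first and becomes the stopSequence pushed onto stopSequences.
```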