Add holdSocket option for faster Petals chat
commit 2e4181bf7e
parent e991dfd9b7
@@ -21,6 +21,7 @@ export class ChatRequest {
   updating: boolean|number = false
   updatingMessage: string = ''
   controller:AbortController
+  providerData: Record<string, any> = {}
 
   setChat (chat: Chat) {
     this.chat = chat
@@ -122,6 +122,7 @@ const defaults:ChatSettings = {
   systemMessageEnd: '',
   leadPrompt: '',
   repetitionPenalty: 1.1,
+  holdSocket: true,
   // useResponseAlteration: false,
   // responseAlterations: [],
   isDirty: false
@@ -451,6 +452,13 @@ const chatSettingsList: ChatSetting[] = [
     type: 'boolean',
     hide: hideModelSetting
   },
+  {
+    key: 'holdSocket',
+    name: 'Continue WebSocket',
+    title: 'Hold WebSocket connection open and try to re-use for each new chat message. Faster, but message delimitation could get mangled.',
+    type: 'boolean',
+    hide: hideModelSetting
+  },
   {
     key: 'temperature',
     name: 'Sampling Temperature',
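In effect, the new setting gates two decisions in the request code further down this diff: how large an inference session to open, and whether to close the socket when a response finishes. A minimal sketch of those two branches (names follow the diff; the functions are mine):

```typescript
// Sketch: the two decisions holdSocket controls in the Petals request code.
// With holdSocket on, the session is opened with the model's full token
// budget so it can serve several messages, and the socket is left open.
const sessionMaxLength = (holdSocket: boolean, maxTokens: number, maxLen: number): number =>
  holdSocket ? maxTokens : maxLen

const shouldCloseOnFinish = (holdSocket: boolean): boolean => !holdSocket
```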
@@ -96,6 +96,7 @@ export type ChatSettings = {
   systemMessageStart: string;
   systemMessageEnd: string;
   repetitionPenalty: number;
+  holdSocket: boolean;
   isDirty?: boolean;
 } & Request;
 
@@ -154,4 +154,8 @@
   return value
 }
 
+export const escapeRegex = (string: string): string => {
+  return string.replace(/[/\-\\^$*+?.()|[\]{}]/g, '\\$&')
+}
+
 </script>
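`escapeRegex` backslash-escapes every regex metacharacter so an arbitrary stop sequence can be embedded in a `RegExp` pattern safely. A quick illustration (the stop sequence value is just an example):

```typescript
// As added above in Util.svelte:
const escapeRegex = (s: string): string => s.replace(/[/\-\\^$*+?.()|[\]{}]/g, '\\$&')

const stopSequence = '[INST]' // example value; '[' and ']' are regex metacharacters
escapeRegex(stopSequence)     // -> '\[INST\]'

// Same pattern shape the Petals request code builds to scrub known text:
const rgxp = new RegExp('(<s>|</s>|\\s|' + escapeRegex(stopSequence) + ')', 'g')
'a [INST] b</s>'.replace(rgxp, '') // -> 'ab'
```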
@@ -20,7 +20,8 @@ const hiddenSettings = {
   assistantMessageEnd: true,
   systemMessageStart: true,
   systemMessageEnd: true,
-  repetitionPenalty: true
+  repetitionPenalty: true,
+  holdSocket: true
   // leadPrompt: true
 }
 
@@ -5,6 +5,29 @@
 import type { ChatCompletionOpts, Message, Request } from '../../Types.svelte'
 import { getModelMaxTokens } from '../../Stats.svelte'
 import { updateMessages } from '../../Storage.svelte'
+import { escapeRegex } from '../../Util.svelte'
 
+const levenshteinDistance = (str1 = '', str2 = '') => {
+  const track = Array(str2.length + 1).fill(null).map(() =>
+    Array(str1.length + 1).fill(null))
+  for (let i = 0; i <= str1.length; i += 1) {
+    track[0][i] = i
+  }
+  for (let j = 0; j <= str2.length; j += 1) {
+    track[j][0] = j
+  }
+  for (let j = 1; j <= str2.length; j += 1) {
+    for (let i = 1; i <= str1.length; i += 1) {
+      const indicator = str1[i - 1] === str2[j - 1] ? 0 : 1
+      track[j][i] = Math.min(
+        track[j][i - 1] + 1, // deletion
+        track[j - 1][i] + 1, // insertion
+        track[j - 1][i - 1] + indicator // substitution
+      )
+    }
+  }
+  return track[str2.length][str1.length]
+}
+
 export const chatRequest = async (
   request: Request,
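`levenshteinDistance` is the standard dynamic-programming edit distance over a (str1.length + 1) × (str2.length + 1) table; the next hunks use it to pick the candidate stop sequence most similar to the chat's start sequence. Illustrative values (the start sequence here is an example, not from the diff):

```typescript
// Using the helper defined in the hunk above.
levenshteinDistance('</s>', '<s>') // 1: delete the '/'
levenshteinDistance('###', '<s>')  // 3: three substitutions

// Sorting candidate stop sequences by distance to an example start
// sequence '<s>' therefore prefers the matching end tag:
['###', '</s>'].sort((a, b) =>
  levenshteinDistance(a, '<s>') - levenshteinDistance(b, '<s>'))
// -> ['</s>', '###']
```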
@@ -16,8 +39,10 @@ export const chatRequest = async (
   const chatSettings = chat.settings
   const model = chatRequest.getModel()
   const modelDetail = getModelDetail(model)
-  const ws = new WebSocket(getEndpoint(model))
   const signal = chatRequest.controller.signal
+  const providerData = chatRequest.providerData.petals || {}
+  chatRequest.providerData.petals = providerData
+  let ws: WebSocket = providerData.ws
   const abortListener = (e:Event) => {
     chatRequest.updating = false
     chatRequest.updatingMessage = ''
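The new `providerData` field on `ChatRequest` gives each provider a scratch area that outlives a single request; the Petals code keeps its held socket and the text already exchanged on it there. An illustrative shape (this type is not declared anywhere in the diff):

```typescript
// Sketch only: what the Petals provider stores between messages
// under chatRequest.providerData.petals.
type PetalsProviderData = {
  ws?: WebSocket       // the held connection, re-used while it stays open
  knownBuffer?: string // prompt + output text the session has already seen
}
```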
@@ -26,9 +51,17 @@ export const chatRequest = async (
     ws.close()
   }
   signal.addEventListener('abort', abortListener)
+  const startSequence = getStartSequence(chat)
   let stopSequences = [...new Set(getStopSequence(chat).split(',').filter(s => s.trim()).concat((modelDetail.stop || ['###', '</s>']).slice()))]
-  const stopSequence = '</s>'
+  let stopSequence = stopSequences[0] || '###'
+  if (startSequence.length) {
+    const sld = stopSequences.slice()
+      .filter(s => s === '###' || s === '</s>' || countTokens(model, s) === 1)
+      .sort((a, b) => levenshteinDistance(a, startSequence) - levenshteinDistance(b, startSequence))
+    stopSequence = sld[0] || stopSequence
+  }
   stopSequences.push(stopSequence)
 
   const delimiter = getDelimiter(chat)
   const leadPromptSequence = getLeadPrompt(chat)
   if (delimiter) stopSequences.unshift(delimiter.trim())
@@ -62,56 +95,55 @@ export const chatRequest = async (
   const buildMessage = (m: Message): string => {
     return getRoleTag(m.role, model, chat) + m.content + getRoleEnd(m.role, model, chat)
   }
+  const buildInputArray = (a) => {
+    return a.reduce((a, m, i) => {
+      let c = buildMessage(m)
+      let replace = false
+      const lm = a[a.length - 1]
+      // Merge content if needed
+      if (lm) {
+        if (lm.role === 'system' && m.role === 'user' && c.includes('[[SYSTEM_PROMPT]]')) {
+          c = c.replaceAll('[[SYSTEM_PROMPT]]', lm.content)
+          replace = true
+        } else {
+          c = c.replaceAll('[[SYSTEM_PROMPT]]', '')
+        }
+        if (lm.role === 'user' && m.role === 'assistant' && c.includes('[[USER_PROMPT]]')) {
+          c = c.replaceAll('[[USER_PROMPT]]', lm.content)
+          replace = true
+        } else {
+          c = c.replaceAll('[[USER_PROMPT]]', '')
+        }
+      }
+      // Clean up merge fields on last
+      if (!rMessages[i + 1]) {
+        c = c.replaceAll('[[USER_PROMPT]]', '').replaceAll('[[SYSTEM_PROMPT]]', '')
+      }
+      const result = {
+        role: m.role,
+        content: c.trim()
+      } as Message
+      if (replace) {
+        a[a.length - 1] = result
+      } else {
+        a.push(result)
+      }
+      return a
+    }, [] as Message[])
+  }
   const lastMessage = rMessages[rMessages.length - 1]
   let doLead = true
   if (lastMessage && lastMessage.role === 'assistant') {
     lastMessage.content = leadPromptSequence + lastMessage.content
     doLead = false
   }
-  const inputArray = rMessages.reduce((a, m, i) => {
-    let c = buildMessage(m)
-    let replace = false
-    const lm = a[a.length - 1]
-    // Merge content if needed
-    if (lm) {
-      if (lm.role === 'system' && m.role === 'user' && c.includes('[[SYSTEM_PROMPT]]')) {
-        c = c.replaceAll('[[SYSTEM_PROMPT]]', lm.content)
-        replace = true
-      } else {
-        c = c.replaceAll('[[SYSTEM_PROMPT]]', '')
-      }
-      if (lm.role === 'user' && m.role === 'assistant' && c.includes('[[USER_PROMPT]]')) {
-        c = c.replaceAll('[[USER_PROMPT]]', lm.content)
-        replace = true
-      } else {
-        c = c.replaceAll('[[USER_PROMPT]]', '')
-      }
-    }
-    // Clean up merge fields on last
-    if (!rMessages[i + 1]) {
-      c = c.replaceAll('[[USER_PROMPT]]', '').replaceAll('[[SYSTEM_PROMPT]]', '')
-    }
-    const result = {
-      role: m.role,
-      content: c.trim()
-    } as Message
-    if (replace) {
-      a[a.length - 1] = result
-    } else {
-      a.push(result)
-    }
-    return a
-  }, [] as Message[])
+  // const inputArray = buildInputArray(rMessages).map(m => m.content)
+  const lInputArray = buildInputArray(rMessages.slice(0, -1)).map(m => m.content)
+  const nInputArray = buildInputArray(rMessages.slice(-1)).map(m => m.content)
   const leadPrompt = (leadPromptSequence && doLead) ? delimiter + leadPromptSequence : ''
-  const fullPromptInput = getStartSequence(chat) + inputArray.map(m => m.content).join(delimiter) + leadPrompt
+  const lastPrompt = startSequence + lInputArray.join(delimiter)
+  const nextPrompt = nInputArray.slice(-1).join('') + leadPrompt
 
-  let maxLen = Math.min(opts.maxTokens || chatSettings.max_tokens || maxTokens, maxTokens)
-  const promptTokenCount = countTokens(model, fullPromptInput)
-  if (promptTokenCount > maxLen) {
-    maxLen = Math.min(maxLen + promptTokenCount, maxTokens)
-  }
-  // update with real count
-  chatResponse.setPromptTokenCount(promptTokenCount)
   // set up the request
   chatResponse.onFinish(() => {
     const message = chatResponse.getMessages()[0]
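The substantive change in this hunk: rather than one `fullPromptInput`, the prompt is now split into `lastPrompt` (everything a held session should already contain) and `nextPrompt` (only the newest message, plus any lead prompt). A toy walk-through with made-up content:

```typescript
// Example values only; real role tags and delimiter come from chat settings.
const delimiter = '\n###\n'
const lInputArray = ['System: You are helpful.', 'User: Hi', 'Assistant: Hello!']
const nInputArray = ['User: And WebSockets are?']

const lastPrompt = '' + lInputArray.join(delimiter) // startSequence assumed empty here
const nextPrompt = nInputArray.slice(-1).join('')   // + leadPrompt when doLead is set

// A held session whose knownBuffer still matches lastPrompt only needs
// to be sent delimiter + nextPrompt, instead of the whole history.
```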
@@ -124,51 +156,119 @@ export const chatRequest = async (
       }
      }
    }
-    ws.close()
+    !chatSettings.holdSocket && ws.close()
  })
-  ws.onopen = () => {
-    ws.send(JSON.stringify({
-      type: 'open_inference_session',
-      model,
-      max_length: maxLen
-    }))
-    ws.onmessage = event => {
+
+  let maxLen = Math.min(opts.maxTokens || chatSettings.max_tokens || maxTokens, maxTokens)
+
+  let inputPrompt = startSequence
+
+  const getNewWs = ():Promise<WebSocket> => new Promise<WebSocket>((resolve, reject) => {
+    // console.warn('requesting new ws')
+    const nws = new WebSocket(getEndpoint(model))
+    let opened = false
+    let done = false
+    nws.onmessage = event => {
+      if (done) return
+      done = true
       const response = JSON.parse(event.data)
       if (!response.ok) {
         const err = new Error('Error opening socket: ' + response.traceback)
         chatResponse.updateFromError(err.message)
+        console.error(err)
+        reject(err)
+      }
+      nws.onerror = err => {
         console.error(err)
         throw err
       }
-      const petalsRequest = {
-        type: 'generate',
-        inputs: fullPromptInput,
-        max_new_tokens: 1, // wait for up to 1 tokens before displaying
-        stop_sequence: stopSequence,
-        do_sample: 1, // enable top p and the like
-        temperature,
-        top_p: topP,
-        repetition_penalty: chatSettings.repetitionPenalty
-      } as any
-      if (stopSequencesC.length) petalsRequest.extra_stop_sequences = stopSequencesC
-      // Update token count
+      // console.warn('got new ws')
+      inputPrompt = lastPrompt
+      providerData.knownBuffer = ''
+      providerData.ws = nws
+      resolve(nws)
+    }
+    nws.onclose = () => {
+      chatResponse.updateFromClose()
+    }
+    nws.onerror = err => {
+      if (done) return
+      done = true
+      console.error(err)
+      reject(err)
+    }
+    nws.onopen = () => {
+      if (opened) return
+      opened = true
+      const promptTokenCount = countTokens(model, lastPrompt + delimiter + nextPrompt)
+      if (promptTokenCount > maxLen) {
+        maxLen = Math.min(maxLen + promptTokenCount, maxTokens)
+      }
+      // update with real count
       chatResponse.setPromptTokenCount(promptTokenCount)
-      ws.send(JSON.stringify(petalsRequest))
-      ws.onmessage = event => {
-        // Remove updating indicator
-        chatRequest.updating = 1 // hide indicator, but still signal we're updating
-        chatRequest.updatingMessage = ''
-        const response = JSON.parse(event.data)
-        if (!response.ok) {
-          if (response.traceback.includes('Maximum length exceeded')) {
-            return chatResponse.finish('length')
-          }
-          const err = new Error('Error in response: ' + response.traceback)
-          console.error(err)
-          chatResponse.updateFromError(err.message)
-          throw err
-        }
-        chatResponse.updateFromAsyncResponse(
+      nws.send(JSON.stringify({
+        type: 'open_inference_session',
+        model,
+        max_length: chatSettings.holdSocket ? maxTokens : maxLen
+      }))
+    }
+  })
+
+  const wsOpen = (ws && ws.readyState !== WebSocket.CLOSED)
+
+  if (!chatSettings.holdSocket || wsOpen) {
+    const rgxp = new RegExp('(<s>|</s>|\\s|' + escapeRegex(stopSequence) + ')', 'g')
+    const kb = providerData.knownBuffer.replace(rgxp, '')
+    const lp = lastPrompt.replace(rgxp, '')
+    const lm = kb === lp
+    if (!lm || countTokens(model, providerData.knownBuffer + inputPrompt) >= maxTokens) {
+      wsOpen && ws.close()
+      ws = await getNewWs()
+    }
+  }
+
+  if (!ws || ws.readyState === WebSocket.CLOSED) {
+    ws = await getNewWs()
+  }
+
+  inputPrompt += delimiter + nextPrompt
+  providerData.knownBuffer += inputPrompt
+
+  // console.log(
+  //   '\n\n*** inputPrompt: ***\n\n',
+  //   inputPrompt
+
+  // )
+
+  const petalsRequest = {
+    type: 'generate',
+    inputs: inputPrompt,
+    max_new_tokens: 1, // wait for up to 1 tokens before displaying
+    stop_sequence: stopSequence,
+    do_sample: 1, // enable top p and the like
+    temperature,
+    top_p: topP,
+    repetition_penalty: chatSettings.repetitionPenalty
+  } as any
+  if (stopSequencesC.length) petalsRequest.extra_stop_sequences = stopSequencesC
+  // Update token count
+  chatResponse.setPromptTokenCount(countTokens(model, providerData.knownBuffer))
+  ws.onmessage = event => {
+    // Remove updating indicator
+    chatRequest.updating = 1 // hide indicator, but still signal we're updating
+    chatRequest.updatingMessage = ''
+    const response = JSON.parse(event.data)
+    if (!response.ok) {
+      if (response.traceback.includes('Maximum length exceeded')) {
+        return chatResponse.finish('length')
+      }
+      const err = new Error('Error in response: ' + response.traceback)
+      console.error(err)
+      chatResponse.updateFromError(err.message)
+      throw err
+    }
+    providerData.knownBuffer += response.outputs
+    chatResponse.updateFromAsyncResponse(
       {
         model,
         choices: [{
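The reuse check above is the heart of the feature: before trusting a held socket, the code scrubs both the remembered session text and the freshly rebuilt prompt of `<s>`/`</s>`, whitespace, and the stop sequence, then compares them; any divergence (say, an edited or deleted message) or an exhausted token budget forces a fresh `getNewWs()`. The same logic as a standalone sketch (the function name is mine, not from the diff):

```typescript
// Hypothetical extraction of the inline check; logic mirrors the diff.
// escapeRegex is the Util.svelte helper added earlier in this commit.
const escapeRegex = (s: string): string => s.replace(/[/\-\\^$*+?.()|[\]{}]/g, '\\$&')

const canReuseSession = (knownBuffer: string, lastPrompt: string, stopSequence: string): boolean => {
  const rgxp = new RegExp('(<s>|</s>|\\s|' + escapeRegex(stopSequence) + ')', 'g')
  // Compare what the session has seen against what the chat now contains,
  // ignoring markup tokens and whitespace that may differ in the echo.
  return knownBuffer.replace(rgxp, '') === lastPrompt.replace(rgxp, '')
}

canReuseSession('<s>Hi there</s>', 'Hi there', '</s>') // -> true
```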
@@ -179,37 +279,32 @@ export const chatRequest = async (
           finish_reason: (response.stop ? 'stop' : null)
         }]
       } as any
     )
     if (chatSettings.aggressiveStop && !response.stop) {
       // check if we should've stopped
       const message = chatResponse.getMessages()[0]
       const pad = 10 // look back 10 characters + stop sequence
       if (message) {
         const mc = (message.content).trim()
         for (let i = 0, l = stopSequences.length; i < l; i++) {
           const ss = stopSequences[i].trim()
           const ind = mc.slice(0 - (ss.length + pad)).indexOf(ss)
           if (ind > -1) {
             const offset = (ss.length + pad) - ind
             message.content = mc.slice(0, mc.length - offset)
             response.stop = true
             updateMessages(chat.id)
             chatResponse.finish()
-            ws.close()
+            if (ss !== stopSequence) {
+              providerData.knownBuffer += stopSequence
+            }
+            ws.close()
           }
         }
       }
     }
   }
-  }
-  ws.onclose = () => {
-    chatResponse.updateFromClose()
-  }
-  ws.onerror = err => {
-    console.error(err)
-    throw err
-  }
-  }
+  ws.send(JSON.stringify(petalsRequest))
   return chatResponse
 }
 </script>
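For orientation, the wire exchange the request code speaks after this commit: one `open_inference_session` frame whenever a socket is (re)opened, then a `generate` frame per turn, with the server streaming back `{ ok, outputs, stop, traceback }` frames. A minimal standalone client sketch; the endpoint URL, model name, sampling values, and the flag used to tell the session ack apart from generation frames are all assumptions, not from the diff:

```typescript
// Minimal sketch of the Petals WebSocket exchange used in this commit.
const ws = new WebSocket('wss://example-petals-host/api/v2/generate') // placeholder URL
let sessionOpen = false // assumed way to distinguish the first ack frame

ws.onopen = () => {
  ws.send(JSON.stringify({ type: 'open_inference_session', model: 'example-model', max_length: 1024 }))
}

ws.onmessage = event => {
  const response = JSON.parse(event.data)
  if (!response.ok) { console.error(response.traceback); return }
  if (!sessionOpen) {
    sessionOpen = true // first ok frame acknowledges the session
    ws.send(JSON.stringify({
      type: 'generate',
      inputs: 'Hello', // prompt text
      max_new_tokens: 1,
      do_sample: 1,
      temperature: 0.7,
      top_p: 0.9,
      stop_sequence: '</s>'
    }))
    return
  }
  console.log(response.outputs)  // streamed text chunk
  if (response.stop) ws.close()  // or keep it open, as holdSocket now does
}
```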