From e3095bdf41fce86ca1fd44f4fc9840866125c9d3 Mon Sep 17 00:00:00 2001
From: Webifi
Date: Sun, 10 Sep 2023 19:32:33 -0500
Subject: [PATCH] Fix some types

---
 src/lib/providers/openai/models.svelte  |  2 +-
 src/lib/providers/petals/models.svelte  | 10 ++++++++--
 src/lib/providers/petals/request.svelte | 15 +++++++++------
 3 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/src/lib/providers/openai/models.svelte b/src/lib/providers/openai/models.svelte
index c46fbe5..2e801ae 100644
--- a/src/lib/providers/openai/models.svelte
+++ b/src/lib/providers/openai/models.svelte
@@ -23,7 +23,7 @@ const hiddenSettings = {
   repetitionPenalty: true,
   holdSocket: true
   // leadPrompt: true
-}
+} as any
 
 const chatModelBase = {
   type: 'chat',
diff --git a/src/lib/providers/petals/models.svelte b/src/lib/providers/petals/models.svelte
index 44467c7..5e3ef96 100644
--- a/src/lib/providers/petals/models.svelte
+++ b/src/lib/providers/petals/models.svelte
@@ -14,7 +14,7 @@ const hideSettings = {
   n: true,
   presence_penalty: true,
   frequency_penalty: true
-}
+} as any
 
 const chatModelBase = {
   type: 'instruct', // Used for chat, but these models operate like instruct models -- you have to manually structure the messages sent to them
@@ -85,8 +85,14 @@ export const chatModels : Record<string, ModelDetail> = {
   },
   'stabilityai/StableBeluga2': {
     ...chatModelBase,
-    label: 'Petals - StableBeluga-2'
+    label: 'Petals - StableBeluga-2-70b'
   }
+  // 'tiiuae/falcon-180B-chat': {
+  //   ...chatModelBase,
+  //   start: '###',
+  //   stop: ['###', '</s>', '<|endoftext|>'],
+  //   label: 'Petals - Falcon-180b-chat'
+  // }
 }
\ No newline at end of file
diff --git a/src/lib/providers/petals/request.svelte b/src/lib/providers/petals/request.svelte
index 33eb7be..70c777d 100644
--- a/src/lib/providers/petals/request.svelte
+++ b/src/lib/providers/petals/request.svelte
@@ -70,13 +70,15 @@ export const chatRequest = async (
   stopSequences = stopSequences.sort((a, b) => b.length - a.length)
   const stopSequencesC = stopSequences.filter(s => s !== stopSequence)
   const maxTokens = getModelMaxTokens(model)
+  const userAfterSystem = true
 
   // Enforce strict order of messages
   const fMessages = (request.messages || [] as Message[])
   const rMessages = fMessages.reduce((a, m, i) => {
     a.push(m)
+    // if (m.role === 'system') m.content = m.content.trim()
     const nm = fMessages[i + 1]
-    if (m.role === 'system' && (!nm || nm.role !== 'user')) {
+    if (userAfterSystem && m.role === 'system' && (!nm || nm.role !== 'user')) {
       const nc = {
         role: 'user',
         content: ''
@@ -97,7 +99,7 @@
   const buildMessage = (m: Message): string => {
     return getRoleTag(m.role, model, chat) + m.content + getRoleEnd(m.role, model, chat)
   }
-  const buildInputArray = (a) => {
+  const buildInputArray = (a: Message[]) => {
     return a.reduce((a, m, i) => {
       let c = buildMessage(m)
       let replace = false
@@ -141,7 +143,7 @@
   }
   // const inputArray = buildInputArray(rMessages).map(m => m.content)
   const lInputArray = doLead
-    ? buildInputArray(rMessages.slice(0, -1)).map(m => m.content)
+    ? (rMessages.length > 1 ? buildInputArray(rMessages.slice(0, -1)).map(m => m.content) : [])
     : buildInputArray(rMessages.slice()).map(m => m.content)
   const nInputArray = buildInputArray(rMessages.slice(-1)).map(m => m.content)
   const leadPrompt = (leadPromptSequence && doLead) ? delimiter + leadPromptSequence : ''
@@ -194,7 +196,7 @@
         throw err
       }
       // console.warn('got new ws')
-      inputPrompt = lastPrompt + (doLead ? delimiter : '')
+      inputPrompt = lastPrompt + (doLead && lInputArray.length ? delimiter : '')
      providerData.knownBuffer = ''
      providerData.ws = nws
      resolve(nws)
@@ -217,11 +219,12 @@
      }
      // update with real count
      chatResponse.setPromptTokenCount(promptTokenCount)
-      nws.send(JSON.stringify({
+      const req = {
        type: 'open_inference_session',
        model,
        max_length: chatSettings.holdSocket ? maxTokens : maxLen
-      }))
+      } as any
+      nws.send(JSON.stringify(req))
    }
  })
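
Reviewer note (not part of the patch): the behavioral core of this change is the message-ordering pass now gated by the new `userAfterSystem` flag. Below is a minimal standalone TypeScript sketch of that rule; the `Message` type and the `enforceUserAfterSystem` helper are simplified stand-ins for illustration, not the app's own definitions.

// Sketch of the ordering rule the patch gates behind `userAfterSystem`:
// every system message must be followed by a user message, so an empty
// user turn is injected wherever one is missing (this mirrors the `nc`
// object pushed by the patched reduce in request.svelte).
type Role = 'system' | 'user' | 'assistant'
type Message = { role: Role, content: string }

const enforceUserAfterSystem = (messages: Message[]): Message[] =>
  messages.reduce((acc: Message[], m, i) => {
    acc.push(m)
    const next = messages[i + 1]
    if (m.role === 'system' && (!next || next.role !== 'user')) {
      // Pad with an empty user message so instruct-style Petals models
      // always see user input after a system prompt.
      acc.push({ role: 'user', content: '' })
    }
    return acc
  }, [])

// ['system', 'assistant'] becomes ['system', 'user', 'assistant']
console.log(enforceUserAfterSystem([
  { role: 'system', content: 'You are helpful.' },
  { role: 'assistant', content: 'Hi! How can I help?' }
]))

The related `lInputArray` changes guard the same path from the degenerate single-message case: `rMessages.slice(0, -1)` on one message yields no lead history, so the builder now returns an empty array there, and the `doLead && lInputArray.length` check keeps a stray leading `delimiter` out of the reconnect prompt.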