diff --git a/src/lib/ChatRequest.svelte b/src/lib/ChatRequest.svelte index c387c78..e9a31b5 100644 --- a/src/lib/ChatRequest.svelte +++ b/src/lib/ChatRequest.svelte @@ -8,7 +8,7 @@ import { getDefaultModel, getRequestSettingList } from './Settings.svelte' import { v4 as uuidv4 } from 'uuid' import { get } from 'svelte/store' - import { getModelDetail } from './Models.svelte' + import { getLeadPrompt, getModelDetail } from './Models.svelte' export class ChatRequest { constructor () { @@ -238,9 +238,10 @@ export class ChatRequest { const lastMessage = messages[messages.length - 1] const isContinue = lastMessage?.role === 'assistant' && lastMessage.finish_reason === 'length' const isUserPrompt = lastMessage?.role === 'user' + let results: Message[] = [] + let injectedPrompt = false if (hiddenPromptPrefix && (isUserPrompt || isContinue)) { - let injectedPrompt = false - const results = hiddenPromptPrefix.split(/[\s\r\n]*::EOM::[\s\r\n]*/).reduce((a, m) => { + results = hiddenPromptPrefix.split(/[\s\r\n]*::EOM::[\s\r\n]*/).reduce((a, m) => { m = m.trim() if (m.length) { if (m.match(/\[\[USER_PROMPT\]\]/)) { @@ -265,9 +266,21 @@ export class ChatRequest { } } if (injectedPrompt) messages.pop() - return results } - return [] + const model = this.getModel() + const messageDetail = getModelDetail(model) + if (getLeadPrompt(this.getChat()).trim() && messageDetail.type === 'chat') { + const lastMessage = (results.length && injectedPrompt && !isContinue) ? 
results[results.length - 1] : messages[messages.length - 1] + if (lastMessage?.role !== 'assistant') { + const leadMessage = { role: 'assistant', content: getLeadPrompt(this.getChat()) } as Message + if (insert) { + messages.push(leadMessage) + } else { + results.push(leadMessage) + } + } + } + return results } /** diff --git a/src/lib/Settings.svelte b/src/lib/Settings.svelte index bd502c7..d17151c 100644 --- a/src/lib/Settings.svelte +++ b/src/lib/Settings.svelte @@ -597,17 +597,6 @@ const chatSettingsList: ChatSetting[] = [ }, hide: hideModelSetting }, - { - key: 'leadPrompt', - name: 'Completion Lead Sequence ', - title: 'Sequence to hint the LLM should answer as assistant.', - type: 'textarea', - placeholder: (chatId) => { - const val = getModelDetail(getChatSettings(chatId).model).leadPrompt - return val || '' - }, - hide: hideModelSetting - }, { key: 'systemMessageStart', name: 'System Message Start Sequence', @@ -630,6 +619,17 @@ const chatSettingsList: ChatSetting[] = [ }, hide: hideModelSetting }, + { + key: 'leadPrompt', + name: 'Completion Lead Sequence', + title: 'Sequence to hint to answer as assistant.', + type: 'textarea', + placeholder: (chatId) => { + const val = getModelDetail(getChatSettings(chatId).model).leadPrompt + return val || '' + }, + hide: hideModelSetting + }, { // logit bias editor not implemented yet key: 'logit_bias', diff --git a/src/lib/Types.svelte b/src/lib/Types.svelte index e94c5cf..aac387f 100644 --- a/src/lib/Types.svelte +++ b/src/lib/Types.svelte @@ -6,7 +6,7 @@ export type Model = typeof supportedChatModelKeys[number]; -export type RequestType = 'chat' | 'image' +export type RequestType = 'chat' | 'instruct' | 'image' export type Usage = { completion_tokens: number; diff --git a/src/lib/providers/openai/models.svelte b/src/lib/providers/openai/models.svelte index 3700d21..09d52a0 100644 --- a/src/lib/providers/openai/models.svelte +++ b/src/lib/providers/openai/models.svelte @@ -18,10 +18,10 @@ const hiddenSettings = { 
userMessageEnd: true, assistantMessageStart: true, assistantMessageEnd: true, - leadPrompt: true, systemMessageStart: true, systemMessageEnd: true, repititionPenalty: true + // leadPrompt: true } const chatModelBase = { diff --git a/src/lib/providers/openai/util.svelte b/src/lib/providers/openai/util.svelte index 10a46e2..9f4190e 100644 --- a/src/lib/providers/openai/util.svelte +++ b/src/lib/providers/openai/util.svelte @@ -49,7 +49,7 @@ const getSupportedModels = async (): Promise<Record<string, boolean>> => { export const checkModel = async (modelDetail: ModelDetail) => { const supportedModels = await getSupportedModels() - if (modelDetail.type === 'chat') { + if (modelDetail.type === 'chat' || modelDetail.type === 'instruct') { modelDetail.enabled = !!supportedModels[modelDetail.modelQuery || ''] } else { // image request. If we have any models, allow image endpoint diff --git a/src/lib/providers/petals/models.svelte b/src/lib/providers/petals/models.svelte index cf57be4..841996e 100644 --- a/src/lib/providers/petals/models.svelte +++ b/src/lib/providers/petals/models.svelte @@ -17,7 +17,7 @@ const hideSettings = { } const chatModelBase = { - type: 'chat', + type: 'instruct', // Used for chat, but these models operate like instruct models -- you have to manually structure the messages sent to them help: 'Below are the settings that can be changed for the API calls. See this overview to start, though not all settings translate to Petals.', check: checkModel, start: '', @@ -45,7 +45,7 @@ const chatModelBase = { return prompts.reduce((a, m) => { a += countMessageTokens(m, model, chat) return a - }, 0) + countTokens(model, getStartSequence(chat)) + countTokens(model, getLeadPrompt(chat)) + }, 0) + countTokens(model, getStartSequence(chat)) + ((prompts[prompts.length - 1] || {}).role !== 'assistant' ? 
countTokens(model, getLeadPrompt(chat)) : 0) } } as ModelDetail @@ -74,6 +74,7 @@ export const chatModels : Record<string, ModelDetail> = { assistantStart: '[[SYSTEM_PROMPT]][[USER_PROMPT]]', systemStart: '<<SYS>>\n', systemEnd: '\n<</SYS>>\n\n' + // leadPrompt: '' }, 'stabilityai/StableBeluga2': { ...chatModelBase, diff --git a/src/lib/providers/petals/util.svelte b/src/lib/providers/petals/util.svelte index 9da7d56..c7f5da8 100644 --- a/src/lib/providers/petals/util.svelte +++ b/src/lib/providers/petals/util.svelte @@ -8,7 +8,7 @@ export const set = (opt: Record<string, any>) => { } export const checkModel = async (modelDetail: ModelDetail) => { - if (modelDetail.type === 'chat') { + if (modelDetail.type === 'chat' || modelDetail.type === 'instruct') { modelDetail.enabled = get(globalStorage).enablePetals }