Add distinction between chat and instruct models

Webifi 2023-08-16 15:20:07 -05:00
parent f4d9774423
commit cb2b9e07f4
7 changed files with 36 additions and 22 deletions
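In outline: the model detail's request type widens from 'chat' | 'image' to 'chat' | 'instruct' | 'image'. Endpoint availability checks treat 'chat' and 'instruct' the same (both use the chat-style endpoint), the Petals models are retyped as 'instruct' because callers must structure their messages by hand, and prompt assembly appends the Completion Lead Sequence only for true 'chat' models. Below is a minimal TypeScript sketch of that split; the helper names isChatEndpoint and wantsLeadMessage and the pared-down ModelDetailSketch type are illustrative assumptions, not part of the repository.

// Sketch only: mirrors the RequestType union added in Types.svelte.
type RequestType = 'chat' | 'instruct' | 'image'

// Pared-down stand-in for the real ModelDetail type.
interface ModelDetailSketch {
  type: RequestType
  enabled?: boolean
}

// Both 'chat' and 'instruct' models talk to the chat-style endpoint,
// so availability checks enable them the same way.
const isChatEndpoint = (d: ModelDetailSketch): boolean =>
  d.type === 'chat' || d.type === 'instruct'

// Only true 'chat' models get the lead message injected automatically;
// 'instruct' models expect the caller to structure the prompt themselves.
const wantsLeadMessage = (d: ModelDetailSketch): boolean => d.type === 'chat'

console.log(isChatEndpoint({ type: 'instruct' }))   // true
console.log(wantsLeadMessage({ type: 'instruct' })) // false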

View File

@@ -8,7 +8,7 @@
 import { getDefaultModel, getRequestSettingList } from './Settings.svelte'
 import { v4 as uuidv4 } from 'uuid'
 import { get } from 'svelte/store'
-import { getModelDetail } from './Models.svelte'
+import { getLeadPrompt, getModelDetail } from './Models.svelte'

 export class ChatRequest {
   constructor () {
@@ -238,9 +238,10 @@ export class ChatRequest {
     const lastMessage = messages[messages.length - 1]
     const isContinue = lastMessage?.role === 'assistant' && lastMessage.finish_reason === 'length'
     const isUserPrompt = lastMessage?.role === 'user'
-    if (hiddenPromptPrefix && (isUserPrompt || isContinue)) {
-      let injectedPrompt = false
-      const results = hiddenPromptPrefix.split(/[\s\r\n]*::EOM::[\s\r\n]*/).reduce((a, m) => {
+    let results: Message[] = []
+    let injectedPrompt = false
+    if (hiddenPromptPrefix && (isUserPrompt || isContinue)) {
+      results = hiddenPromptPrefix.split(/[\s\r\n]*::EOM::[\s\r\n]*/).reduce((a, m) => {
         m = m.trim()
         if (m.length) {
           if (m.match(/\[\[USER_PROMPT\]\]/)) {
@@ -265,9 +266,21 @@ export class ChatRequest {
         }
       }
       if (injectedPrompt) messages.pop()
-      return results
     }
-    return []
+    const model = this.getModel()
+    const messageDetail = getModelDetail(model)
+    if (getLeadPrompt(this.getChat()).trim() && messageDetail.type === 'chat') {
+      const lastMessage = (results.length && injectedPrompt && !isContinue) ? results[results.length - 1] : messages[messages.length - 1]
+      if (lastMessage?.role !== 'assistant') {
+        const leadMessage = { role: 'assistant', content: getLeadPrompt(this.getChat()) } as Message
+        if (insert) {
+          messages.push(leadMessage)
+        } else {
+          results.push(leadMessage)
+        }
+      }
+    }
+    return results
   }

   /**
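The added block above is where the distinction takes effect at request time: when the chat's lead prompt is non-empty and the model's type is 'chat', an assistant-role lead message is appended, unless the conversation already ends with an assistant message. A rough standalone sketch of that rule; the MessageSketch shape and the appendLeadMessage helper are assumptions for illustration, not the repository's API.

// Sketch only: approximates the lead-message rule added to ChatRequest.
type Role = 'system' | 'user' | 'assistant'
interface MessageSketch { role: Role; content: string }

// Hypothetical helper: append an assistant 'lead' message for chat-type
// models, but never stack it on top of an existing assistant message.
function appendLeadMessage (messages: MessageSketch[], leadPrompt: string, modelType: string): MessageSketch[] {
  if (!leadPrompt.trim() || modelType !== 'chat') return messages
  const last = messages[messages.length - 1]
  if (last && last.role === 'assistant') return messages
  return [...messages, { role: 'assistant', content: leadPrompt }]
}

// A trailing user message gets the lead appended; a trailing assistant message does not.
console.log(appendLeadMessage([{ role: 'user', content: 'Hi' }], 'Assistant:', 'chat').length)          // 2
console.log(appendLeadMessage([{ role: 'assistant', content: 'Hello' }], 'Assistant:', 'chat').length)  // 1
console.log(appendLeadMessage([{ role: 'user', content: 'Hi' }], 'Assistant:', 'instruct').length)      // 1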

View File

@@ -597,17 +597,6 @@ const chatSettingsList: ChatSetting[] = [
     },
     hide: hideModelSetting
   },
-  {
-    key: 'leadPrompt',
-    name: 'Completion Lead Sequence ',
-    title: 'Sequence to hint the LLM should answer as assistant.',
-    type: 'textarea',
-    placeholder: (chatId) => {
-      const val = getModelDetail(getChatSettings(chatId).model).leadPrompt
-      return val || ''
-    },
-    hide: hideModelSetting
-  },
   {
     key: 'systemMessageStart',
     name: 'System Message Start Sequence',
@@ -630,6 +619,17 @@ const chatSettingsList: ChatSetting[] = [
     },
     hide: hideModelSetting
   },
+  {
+    key: 'leadPrompt',
+    name: 'Completion Lead Sequence',
+    title: 'Sequence to hint to answer as assistant.',
+    type: 'textarea',
+    placeholder: (chatId) => {
+      const val = getModelDetail(getChatSettings(chatId).model).leadPrompt
+      return val || ''
+    },
+    hide: hideModelSetting
+  },
   {
     // logit bias editor not implemented yet
     key: 'logit_bias',

View File

@@ -6,7 +6,7 @@
 export type Model = typeof supportedChatModelKeys[number];

-export type RequestType = 'chat' | 'image'
+export type RequestType = 'chat' | 'instruct' | 'image'

 export type Usage = {
   completion_tokens: number;

View File

@@ -18,10 +18,10 @@ const hiddenSettings = {
   userMessageEnd: true,
   assistantMessageStart: true,
   assistantMessageEnd: true,
-  leadPrompt: true,
   systemMessageStart: true,
   systemMessageEnd: true,
   repititionPenalty: true
+  // leadPrompt: true
 }

 const chatModelBase = {

View File

@@ -49,7 +49,7 @@ const getSupportedModels = async (): Promise<Record<string, boolean>> => {

 export const checkModel = async (modelDetail: ModelDetail) => {
   const supportedModels = await getSupportedModels()
-  if (modelDetail.type === 'chat') {
+  if (modelDetail.type === 'chat' || modelDetail.type === 'instruct') {
     modelDetail.enabled = !!supportedModels[modelDetail.modelQuery || '']
   } else {
     // image request. If we have any models, allow image endpoint

View File

@ -17,7 +17,7 @@ const hideSettings = {
} }
const chatModelBase = { const chatModelBase = {
type: 'chat', type: 'instruct', // Used for chat, but these models operate like instruct models -- you have to manually structure the messages sent to them
help: 'Below are the settings that can be changed for the API calls. See <a target="_blank" href="https://platform.openai.com/docs/api-reference/chat/create">this overview</a> to start, though not all settings translate to Petals.', help: 'Below are the settings that can be changed for the API calls. See <a target="_blank" href="https://platform.openai.com/docs/api-reference/chat/create">this overview</a> to start, though not all settings translate to Petals.',
check: checkModel, check: checkModel,
start: '<s>', start: '<s>',
@@ -45,7 +45,7 @@ const chatModelBase = {
     return prompts.reduce((a, m) => {
       a += countMessageTokens(m, model, chat)
       return a
-    }, 0) + countTokens(model, getStartSequence(chat)) + countTokens(model, getLeadPrompt(chat))
+    }, 0) + countTokens(model, getStartSequence(chat)) + ((prompts[prompts.length - 1] || {}).role !== 'assistant' ? countTokens(model, getLeadPrompt(chat)) : 0)
   }
 } as ModelDetail
@@ -74,6 +74,7 @@ export const chatModels : Record<string, ModelDetail> = {
     assistantStart: '[[SYSTEM_PROMPT]][[USER_PROMPT]]',
     systemStart: '<<SYS>>\n',
     systemEnd: '\n<</SYS>>\n\n'
+    // leadPrompt: ''
   },
   'stabilityai/StableBeluga2': {
     ...chatModelBase,
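The countPromptTokens change in this file keeps token accounting consistent with that rule: lead-prompt tokens are added only when the final prompt is not already an assistant message, i.e. only when the lead would actually be sent. A rough sketch of that accounting; countTokensSketch is a whitespace stand-in for the model's real tokenizer, and countPromptTokensSketch is an illustrative name, not the repository's function.

// Sketch only: count lead-prompt tokens conditionally, as the Petals
// countPromptTokens change above does.
interface PromptSketch { role: 'system' | 'user' | 'assistant'; content: string }

// Stand-in tokenizer: the real code uses the model's tokenizer, not whitespace splitting.
const countTokensSketch = (text: string): number =>
  text.trim() ? text.trim().split(/\s+/).length : 0

function countPromptTokensSketch (prompts: PromptSketch[], startSequence: string, leadPrompt: string): number {
  const base = prompts.reduce((a, p) => a + countTokensSketch(p.content), 0)
  const last = prompts[prompts.length - 1]
  const lastIsAssistant = !!last && last.role === 'assistant'
  // Only pay for the lead prompt when it would actually be appended.
  return base + countTokensSketch(startSequence) + (lastIsAssistant ? 0 : countTokensSketch(leadPrompt))
}

console.log(countPromptTokensSketch([{ role: 'user', content: 'Hello there' }], '<s>', 'Assistant:'))  // lead counted
console.log(countPromptTokensSketch([{ role: 'assistant', content: 'Hi' }], '<s>', 'Assistant:'))      // lead skipped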

View File

@@ -8,7 +8,7 @@ export const set = (opt: Record<string, any>) => {
 }

 export const checkModel = async (modelDetail: ModelDetail) => {
-  if (modelDetail.type === 'chat') {
+  if (modelDetail.type === 'chat' || modelDetail.type === 'instruct') {
     modelDetail.enabled = get(globalStorage).enablePetals
   }
 }
} }