Merge pull request #247 from Webifi/main

Close websocket connections, Format Llama-2-chat requests to Meta spec.
This commit is contained in:
Niek van der Maas 2023-07-29 07:00:45 +02:00 committed by GitHub
commit 192e5ea5a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 216 additions and 21 deletions

View File

@ -1,7 +1,7 @@
<script context="module" lang="ts"> <script context="module" lang="ts">
import ChatCompletionResponse from './ChatCompletionResponse.svelte' import ChatCompletionResponse from './ChatCompletionResponse.svelte'
import ChatRequest from './ChatRequest.svelte' import ChatRequest from './ChatRequest.svelte'
import { getEndpoint, getModelDetail, getRoleTag, getStopSequence } from './Models.svelte' import { getDeliminator, getEndpoint, getLeadPrompt, getModelDetail, getRoleEnd, getRoleTag, getStartSequence, getStopSequence } from './Models.svelte'
import type { ChatCompletionOpts, Message, Request } from './Types.svelte' import type { ChatCompletionOpts, Message, Request } from './Types.svelte'
import { getModelMaxTokens } from './Stats.svelte' import { getModelMaxTokens } from './Stats.svelte'
import { updateMessages } from './Storage.svelte' import { updateMessages } from './Storage.svelte'
@ -27,13 +27,18 @@ export const runPetalsCompletionRequest = async (
signal.addEventListener('abort', abortListener) signal.addEventListener('abort', abortListener)
const stopSequences = (modelDetail.stop || ['###', '</s>']).slice() const stopSequences = (modelDetail.stop || ['###', '</s>']).slice()
const stopSequence = getStopSequence(chat) const stopSequence = getStopSequence(chat)
const deliminator = getDeliminator(chat)
if (deliminator) stopSequences.unshift(deliminator)
let stopSequenceC = stopSequence let stopSequenceC = stopSequence
if (stopSequence !== '###') { if (stopSequence !== '###') {
stopSequences.push(stopSequence) stopSequences.push(stopSequence)
stopSequenceC = '</s>' stopSequenceC = '</s>'
} }
const haveSeq = {}
const stopSequencesC = stopSequences.filter((ss) => { const stopSequencesC = stopSequences.filter((ss) => {
return ss !== '###' && ss !== stopSequenceC const have = haveSeq[ss]
haveSeq[ss] = true
return !have && ss !== '###' && ss !== stopSequenceC
}) })
const maxTokens = getModelMaxTokens(model) const maxTokens = getModelMaxTokens(model)
let maxLen = Math.min(opts.maxTokens || chatRequest.chat.max_tokens || maxTokens, maxTokens) let maxLen = Math.min(opts.maxTokens || chatRequest.chat.max_tokens || maxTokens, maxTokens)
@ -54,6 +59,7 @@ export const runPetalsCompletionRequest = async (
} }
chatRequest.updating = false chatRequest.updating = false
chatRequest.updatingMessage = '' chatRequest.updatingMessage = ''
ws.close()
}) })
ws.onopen = () => { ws.onopen = () => {
ws.send(JSON.stringify({ ws.send(JSON.stringify({
@ -69,7 +75,21 @@ export const runPetalsCompletionRequest = async (
console.error(err) console.error(err)
throw err throw err
} }
const rMessages = request.messages || [] as Message[] // Enforce strict order of messages
const fMessages = (request.messages || [] as Message[])
const rMessages = fMessages.reduce((a, m, i) => {
a.push(m)
const nm = fMessages[i + 1]
if (m.role === 'system' && (!nm || nm.role !== 'user')) {
const nc = {
role: 'user',
content: ''
} as Message
a.push(nc)
}
return a
},
[] as Message[])
// make sure top_p and temperature are set the way we need // make sure top_p and temperature are set the way we need
let temperature = request.temperature let temperature = request.temperature
if (temperature === undefined || isNaN(temperature as any)) temperature = 1 if (temperature === undefined || isNaN(temperature as any)) temperature = 1
@ -78,18 +98,47 @@ export const runPetalsCompletionRequest = async (
if (topP === undefined || isNaN(topP as any)) topP = 1 if (topP === undefined || isNaN(topP as any)) topP = 1
if (!topP || topP <= 0) topP = 0.01 if (!topP || topP <= 0) topP = 0.01
// build the message array // build the message array
const inputArray = (rMessages).reduce((a, m) => { const buildMessage = (m: Message): string => {
const c = getRoleTag(m.role, model, chatRequest.chat) + m.content return getRoleTag(m.role, model, chat) + m.content + getRoleEnd(m.role, model, chat)
a.push(c.trim())
return a
}, [] as string[])
const lastMessage = rMessages[rMessages.length - 1]
if (lastMessage && lastMessage.role !== 'assistant') {
inputArray.push(getRoleTag('assistant', model, chatRequest.chat))
} }
const inputArray = rMessages.reduce((a, m, i) => {
let c = buildMessage(m)
let replace = false
const lm = a[a.length - 1]
// Merge content if needed
if (lm) {
if (lm.role === 'system' && m.role === 'user' && c.includes('[[SYSTEM_PROMPT]]')) {
c = c.replaceAll('[[SYSTEM_PROMPT]]', lm.content)
replace = true
} else {
c = c.replaceAll('[[SYSTEM_PROMPT]]', '')
}
if (lm.role === 'user' && m.role === 'assistant' && c.includes('[[USER_PROMPT]]')) {
c = c.replaceAll('[[USER_PROMPT]]', lm.content)
replace = true
} else {
c = c.replaceAll('[[USER_PROMPT]]', '')
}
}
// Clean up merge fields on last
if (!rMessages[i + 1]) {
c = c.replaceAll('[[USER_PROMPT]]', '').replaceAll('[[SYSTEM_PROMPT]]', '')
}
const result = {
role: m.role,
content: c.trim()
} as Message
if (replace) {
a[a.length - 1] = result
} else {
a.push(result)
}
return a
}, [] as Message[])
const leadPrompt = ((inputArray[inputArray.length - 1] || {}) as Message).role !== 'assistant' ? getLeadPrompt(chat) : ''
const petalsRequest = { const petalsRequest = {
type: 'generate', type: 'generate',
inputs: inputArray.join(stopSequence), inputs: getStartSequence(chat) + inputArray.map(m => m.content).join(deliminator) + leadPrompt,
max_new_tokens: 1, // wait for up to 1 tokens before displaying max_new_tokens: 1, // wait for up to 1 tokens before displaying
stop_sequence: stopSequenceC, stop_sequence: stopSequenceC,
do_sample: 1, // enable top p and the like do_sample: 1, // enable top p and the like

View File

@ -46,6 +46,7 @@ const modelDetails : Record<string, ModelDetail> = {
type: 'Petals', type: 'Petals',
label: 'Petals - Llama-65b', label: 'Petals - Llama-65b',
stop: ['###', '</s>'], stop: ['###', '</s>'],
deliminator: '###',
userStart: '<|user|>', userStart: '<|user|>',
assistantStart: '<|[[CHARACTER_NAME]]|>', assistantStart: '<|[[CHARACTER_NAME]]|>',
systemStart: '', systemStart: '',
@ -57,6 +58,7 @@ const modelDetails : Record<string, ModelDetail> = {
type: 'Petals', type: 'Petals',
label: 'Petals - Guanaco-65b', label: 'Petals - Guanaco-65b',
stop: ['###', '</s>'], stop: ['###', '</s>'],
deliminator: '###',
userStart: '<|user|>', userStart: '<|user|>',
assistantStart: '<|[[CHARACTER_NAME]]|>', assistantStart: '<|[[CHARACTER_NAME]]|>',
systemStart: '', systemStart: '',
@ -67,10 +69,15 @@ const modelDetails : Record<string, ModelDetail> = {
'meta-llama/Llama-2-70b-chat-hf': { 'meta-llama/Llama-2-70b-chat-hf': {
type: 'Petals', type: 'Petals',
label: 'Petals - Llama-2-70b-chat', label: 'Petals - Llama-2-70b-chat',
stop: ['###', '</s>'], start: '<s>',
userStart: '<|user|>', stop: ['</s>'],
assistantStart: '<|[[CHARACTER_NAME]]|>', deliminator: ' </s><s>',
systemStart: '', userStart: '[INST][[SYSTEM_PROMPT]]',
userEnd: ' [/INST]',
assistantStart: '[[SYSTEM_PROMPT]][[USER_PROMPT]]',
assistantEnd: '',
systemStart: '<<SYS>>\n',
systemEnd: '\n<</SYS>>\n\n',
prompt: 0.000000, // $0.000 per 1000 tokens prompt prompt: 0.000000, // $0.000 per 1000 tokens prompt
completion: 0.000000, // $0.000 per 1000 tokens completion completion: 0.000000, // $0.000 per 1000 tokens completion
max: 4096 // 4k max token buffer max: 4096 // 4k max token buffer
@ -177,10 +184,29 @@ export const getEndpoint = (model: Model): string => {
} }
} }
export const getStartSequence = (chat: Chat): string => {
return mergeProfileFields(
chat.settings,
chat.settings.startSequence || valueOf(chat.id, getChatSettingObjectByKey('startSequence').placeholder)
)
}
export const getStopSequence = (chat: Chat): string => { export const getStopSequence = (chat: Chat): string => {
return chat.settings.stopSequence || valueOf(chat.id, getChatSettingObjectByKey('stopSequence').placeholder) return chat.settings.stopSequence || valueOf(chat.id, getChatSettingObjectByKey('stopSequence').placeholder)
} }
export const getDeliminator = (chat: Chat): string => {
return chat.settings.deliminator || valueOf(chat.id, getChatSettingObjectByKey('deliminator').placeholder)
}
export const getLeadPrompt = (chat: Chat): string => {
return mergeProfileFields(
chat.settings,
chat.settings.leadPrompt || valueOf(chat.id, getChatSettingObjectByKey('leadPrompt').placeholder)
)
}
export const getUserStart = (chat: Chat): string => { export const getUserStart = (chat: Chat): string => {
return mergeProfileFields( return mergeProfileFields(
chat.settings, chat.settings,
@ -188,6 +214,13 @@ export const getUserStart = (chat: Chat): string => {
) )
} }
export const getUserEnd = (chat: Chat): string => {
return mergeProfileFields(
chat.settings,
chat.settings.userMessageEnd || valueOf(chat.id, getChatSettingObjectByKey('userMessageEnd').placeholder)
)
}
export const getAssistantStart = (chat: Chat): string => { export const getAssistantStart = (chat: Chat): string => {
return mergeProfileFields( return mergeProfileFields(
chat.settings, chat.settings,
@ -195,6 +228,13 @@ export const getAssistantStart = (chat: Chat): string => {
) )
} }
export const getAssistantEnd = (chat: Chat): string => {
return mergeProfileFields(
chat.settings,
chat.settings.assistantMessageEnd || valueOf(chat.id, getChatSettingObjectByKey('assistantMessageEnd').placeholder)
)
}
export const getSystemStart = (chat: Chat): string => { export const getSystemStart = (chat: Chat): string => {
return mergeProfileFields( return mergeProfileFields(
chat.settings, chat.settings,
@ -202,6 +242,13 @@ export const getSystemStart = (chat: Chat): string => {
) )
} }
export const getSystemEnd = (chat: Chat): string => {
return mergeProfileFields(
chat.settings,
chat.settings.systemMessageEnd || valueOf(chat.id, getChatSettingObjectByKey('systemMessageEnd').placeholder)
)
}
export const getRoleTag = (role: string, model: Model, chat: Chat): string => { export const getRoleTag = (role: string, model: Model, chat: Chat): string => {
const modelDetails = getModelDetail(model) const modelDetails = getModelDetail(model)
switch (modelDetails.type) { switch (modelDetails.type) {
@ -217,6 +264,21 @@ export const getRoleTag = (role: string, model: Model, chat: Chat): string => {
} }
} }
export const getRoleEnd = (role: string, model: Model, chat: Chat): string => {
const modelDetails = getModelDetail(model)
switch (modelDetails.type) {
case 'Petals':
if (role === 'assistant') return getAssistantEnd(chat)
if (role === 'user') return getUserEnd(chat)
return getSystemEnd(chat)
case 'OpenAIDall-e':
return ''
case 'OpenAIChat':
default:
return ''
}
}
export const getTokens = (model: Model, value: string): number[] => { export const getTokens = (model: Model, value: string): number[] => {
const modelDetails = getModelDetail(model) const modelDetails = getModelDetail(model)
switch (modelDetails.type) { switch (modelDetails.type) {

View File

@ -109,11 +109,17 @@ const defaults:ChatSettings = {
hppContinuePrompt: '', hppContinuePrompt: '',
hppWithSummaryPrompt: false, hppWithSummaryPrompt: false,
imageGenerationSize: '', imageGenerationSize: '',
startSequence: '',
stopSequence: '', stopSequence: '',
aggressiveStop: false, aggressiveStop: false,
deliminator: '',
userMessageStart: '', userMessageStart: '',
userMessageEnd: '',
assistantMessageStart: '', assistantMessageStart: '',
assistantMessageEnd: '',
systemMessageStart: '', systemMessageStart: '',
systemMessageEnd: '',
leadPrompt: '',
// useResponseAlteration: false, // useResponseAlteration: false,
// responseAlterations: [], // responseAlterations: [],
isDirty: false isDirty: false
@ -514,10 +520,21 @@ const chatSettingsList: ChatSetting[] = [
type: 'number', type: 'number',
hide: isNotOpenAI hide: isNotOpenAI
}, },
{
key: 'startSequence',
name: 'Start Sequence',
title: 'Characters used to start the message chain.',
type: 'textarea',
placeholder: (chatId) => {
const val = getModelDetail(getChatSettings(chatId).model).start
return val || ''
},
hide: isNotPetals
},
{ {
key: 'stopSequence', key: 'stopSequence',
name: 'Stop Sequence', name: 'Stop Sequence',
title: 'Characters used to separate messages in the message chain.', title: 'Characters used to signal end of message chain.',
type: 'text', type: 'text',
placeholder: (chatId) => { placeholder: (chatId) => {
const val = getModelDetail(getChatSettings(chatId).model).stop const val = getModelDetail(getChatSettings(chatId).model).stop
@ -532,39 +549,94 @@ const chatSettingsList: ChatSetting[] = [
type: 'boolean', type: 'boolean',
hide: isNotPetals hide: isNotPetals
}, },
{
key: 'deliminator',
name: 'Deliminator Sequence',
title: 'Characters used to separate messages in the message chain.',
type: 'textarea',
placeholder: (chatId) => {
const val = getModelDetail(getChatSettings(chatId).model).deliminator
return val || ''
},
hide: isNotPetals
},
{ {
key: 'userMessageStart', key: 'userMessageStart',
name: 'User Message Start Sequence', name: 'User Message Start Sequence',
title: 'Sequence to denote user messages in the message chain.', title: 'Sequence to denote start of user messages in the message chain.',
type: 'text', type: 'textarea',
placeholder: (chatId) => { placeholder: (chatId) => {
const val = getModelDetail(getChatSettings(chatId).model).userStart const val = getModelDetail(getChatSettings(chatId).model).userStart
return val || '' return val || ''
}, },
hide: isNotPetals hide: isNotPetals
}, },
{
key: 'userMessageEnd',
name: 'User Message End Sequence',
title: 'Sequence to denote end of user messages in the message chain.',
type: 'textarea',
placeholder: (chatId) => {
const val = getModelDetail(getChatSettings(chatId).model).userEnd
return val || ''
},
hide: isNotPetals
},
{ {
key: 'assistantMessageStart', key: 'assistantMessageStart',
name: 'Assistant Message Start Sequence', name: 'Assistant Message Start Sequence',
title: 'Sequence to denote assistant messages in the message chain.', title: 'Sequence to denote assistant messages in the message chain.',
type: 'text', type: 'textarea',
placeholder: (chatId) => { placeholder: (chatId) => {
const val = getModelDetail(getChatSettings(chatId).model).assistantStart const val = getModelDetail(getChatSettings(chatId).model).assistantStart
return val || '' return val || ''
}, },
hide: isNotPetals hide: isNotPetals
}, },
{
key: 'assistantMessageEnd',
name: 'Assistant Message End Sequence',
title: 'Sequence to denote end of assistant messages in the message chain.',
type: 'textarea',
placeholder: (chatId) => {
const val = getModelDetail(getChatSettings(chatId).model).assistantEnd
return val || ''
},
hide: isNotPetals
},
{
key: 'leadPrompt',
name: 'Completion Lead Sequence ',
title: 'Sequence to hint the LLM should answer as assistant.',
type: 'textarea',
placeholder: (chatId) => {
const val = getModelDetail(getChatSettings(chatId).model).leadPrompt
return val || ''
},
hide: isNotPetals
},
{ {
key: 'systemMessageStart', key: 'systemMessageStart',
name: 'System Message Start Sequence', name: 'System Message Start Sequence',
title: 'Sequence to denote system messages in the message chain.', title: 'Sequence to denote system messages in the message chain.',
type: 'text', type: 'textarea',
placeholder: (chatId) => { placeholder: (chatId) => {
const val = getModelDetail(getChatSettings(chatId).model).systemStart const val = getModelDetail(getChatSettings(chatId).model).systemStart
return val || '' return val || ''
}, },
hide: isNotPetals hide: isNotPetals
}, },
{
key: 'systemMessageEnd',
name: 'System Message End Sequence',
title: 'Sequence to denote end of system messages in the message chain.',
type: 'textarea',
placeholder: (chatId) => {
const val = getModelDetail(getChatSettings(chatId).model).systemEnd
return val || ''
},
hide: isNotPetals
},
{ {
// logit bias editor not implemented yet // logit bias editor not implemented yet
key: 'logit_bias', key: 'logit_bias',

View File

@ -12,10 +12,16 @@ export type RequestType = 'OpenAIChat' | 'OpenAIDall-e' | 'Petals'
export type ModelDetail = { export type ModelDetail = {
type: RequestType; type: RequestType;
label?: string; label?: string;
start?: string;
stop?: string[]; stop?: string[];
deliminator?: string;
userStart?: string, userStart?: string,
userEnd?: string,
assistantStart?: string, assistantStart?: string,
assistantEnd?: string,
systemStart?: string, systemStart?: string,
systemEnd?: string,
leadPrompt?: string,
prompt: number; prompt: number;
completion: number; completion: number;
max: number; max: number;
@ -113,11 +119,17 @@ export type ChatSettings = {
trainingPrompts?: Message[]; trainingPrompts?: Message[];
useResponseAlteration?: boolean; useResponseAlteration?: boolean;
responseAlterations?: ResponseAlteration[]; responseAlterations?: ResponseAlteration[];
startSequence: string;
stopSequence: string; stopSequence: string;
aggressiveStop: boolean; aggressiveStop: boolean;
deliminator: string;
userMessageStart: string; userMessageStart: string;
userMessageEnd: string;
assistantMessageStart: string; assistantMessageStart: string;
assistantMessageEnd: string;
leadPrompt: string;
systemMessageStart: string; systemMessageStart: string;
systemMessageEnd: string;
isDirty?: boolean; isDirty?: boolean;
} & Request; } & Request;