Fix some issues with stop sequences and role sequences

Webifi 2023-07-24 19:48:28 -05:00
parent f56e29b829
commit 38d38bf948
2 changed files with 64 additions and 35 deletions

Changed file 1 of 2:

@@ -25,10 +25,16 @@ export const runPetalsCompletionRequest = async (
     ws.close()
   }
   signal.addEventListener('abort', abortListener)
-  const stopSequences = modelDetail.stop || ['###']
+  const stopSequences = (modelDetail.stop || ['###', '</s>']).slice()
   const stopSequence = getStopSequence(chat)
-  const stopSequencesC = stopSequences.slice()
-  if (stopSequence === stopSequencesC[0]) stopSequencesC.shift()
+  let stopSequenceC = stopSequence
+  if (stopSequence !== '###') {
+    stopSequences.push(stopSequence)
+    stopSequenceC = '</s>'
+  }
+  const stopSequencesC = stopSequences.filter((ss) => {
+    return ss !== '###' && ss !== stopSequenceC
+  })
   const maxTokens = getModelMaxTokens(model)
   let maxLen = Math.min(opts.maxTokens || chatRequest.chat.max_tokens || maxTokens, maxTokens)
   const promptTokenCount = chatResponse.getPromptTokenCount()
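
A sketch of the stop-sequence selection this hunk performs, pulled out into a hypothetical helper for illustration (resolveStopSequences is not a name in the codebase; the commit keeps the logic inline):

    // Copy the model's stop list so pushing the chat's stop sequence never
    // mutates the shared modelDetail object.
    function resolveStopSequences (modelStop: string[] | undefined, chatStop: string) {
      const stopSequences = (modelStop || ['###', '</s>']).slice()
      // The request carries a single primary stop_sequence; when the chat uses a
      // custom role separator, '</s>' becomes the primary and the separator rides
      // along as an extra stop sequence instead.
      let stopSequenceC = chatStop
      if (chatStop !== '###') {
        stopSequences.push(chatStop)
        stopSequenceC = '</s>'
      }
      // Everything other than '###' and the primary goes into extra_stop_sequences.
      const stopSequencesC = stopSequences.filter((ss) => ss !== '###' && ss !== stopSequenceC)
      return { stopSequences, stopSequenceC, stopSequencesC }
    }

    // resolveStopSequences(['###', '</s>'], '<|user|>')
    //   -> { stopSequences: ['###', '</s>', '<|user|>'], stopSequenceC: '</s>', stopSequencesC: ['<|user|>'] }
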
@@ -36,6 +42,16 @@ export const runPetalsCompletionRequest = async (
     maxLen = Math.min(maxLen + promptTokenCount, maxTokens)
   }
   chatResponse.onFinish(() => {
+    const message = chatResponse.getMessages()[0]
+    if (message) {
+      for (let i = 0, l = stopSequences.length; i < l; i++) {
+        const ss = stopSequences[i].trim()
+        if (message.content.trim().endsWith(ss)) {
+          message.content = message.content.trim().slice(0, message.content.trim().length - ss.length)
+          updateMessages(chat.id)
+        }
+      }
+    }
     chatRequest.updating = false
     chatRequest.updatingMessage = ''
   })
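
The trimming that onFinish now does can be read as a small pure function over the final message text; trimTrailingStop is a hypothetical name for illustration, the commit applies this inline to the first message of chatResponse and then calls updateMessages:

    // Strip a trailing stop sequence (after trimming whitespace), as the new
    // onFinish handler does for the completed assistant message.
    function trimTrailingStop (content: string, stopSequences: string[]): string {
      let out = content.trim()
      for (const stop of stopSequences) {
        const ss = stop.trim()
        if (ss && out.endsWith(ss)) out = out.slice(0, out.length - ss.length)
      }
      return out
    }

    // trimTrailingStop('Sure, here you go.</s>', ['###', '</s>']) -> 'Sure, here you go.'
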
@@ -55,8 +71,8 @@ export const runPetalsCompletionRequest = async (
   }
   const rMessages = request.messages || [] as Message[]
   // make sure top_p and temperature are set the way we need
-  let temperature = request.temperature || 0
-  if (isNaN(temperature as any)) temperature = 1
+  let temperature = request.temperature
+  if (temperature === undefined || isNaN(temperature as any)) temperature = 1
   if (!temperature || temperature <= 0) temperature = 0.01
   let topP = request.top_p
   if (topP === undefined || isNaN(topP as any)) topP = 1
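
The sampling parameters are sanitized so the backend always gets usable values: an unset or NaN temperature now falls back to 1, and zero or negative temperatures are bumped to 0.01 (presumably because do_sample is enabled and a temperature of 0 is not meaningful when sampling). A sketch of the mapping as a standalone function with a hypothetical name:

    function sanitizeSampling (temperature?: number, topP?: number) {
      let t = temperature
      if (t === undefined || isNaN(t as any)) t = 1   // unset or NaN -> default of 1
      if (!t || t <= 0) t = 0.01                      // zero or negative -> smallest usable value
      let p = topP
      if (p === undefined || isNaN(p as any)) p = 1   // unset or NaN -> default of 1
      return { temperature: t, top_p: p }
    }

    // sanitizeSampling(undefined, 0.9) -> { temperature: 1, top_p: 0.9 }
    // sanitizeSampling(0, undefined)   -> { temperature: 0.01, top_p: 1 }
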
@@ -64,7 +80,7 @@ export const runPetalsCompletionRequest = async (
   // build the message array
   const inputArray = (rMessages).reduce((a, m) => {
     const c = getRoleTag(m.role, model, chatRequest.chat) + m.content
-    a.push(c)
+    a.push(c.trim())
     return a
   }, [] as string[])
   const lastMessage = rMessages[rMessages.length - 1]
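
Each message is now trimmed after its role tag is prepended, so stray leading or trailing whitespace in message content cannot end up around the role separators once the prompt is joined. A toy illustration of that assembly (the simplified role tags are stand-ins; the real ones come from getRoleTag and the chat/model settings):

    type Role = 'user' | 'assistant' | 'system'
    // Simplified stand-in for getRoleTag(m.role, model, chat).
    const roleTag = (role: Role): string =>
      role === 'user' ? '<|user|>' : role === 'assistant' ? '<|Assistant|>' : ''

    const messages: { role: Role, content: string }[] = [
      { role: 'system', content: 'You are helpful. ' },
      { role: 'user', content: ' Hello!' }
    ]
    const inputs = messages
      .map((m) => (roleTag(m.role) + m.content).trim())  // a.push(c.trim()) above
      .join('###')                                       // joined on the chat's stop sequence
    // -> 'You are helpful.###<|user|> Hello!'
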
@@ -75,12 +91,12 @@ export const runPetalsCompletionRequest = async (
     type: 'generate',
     inputs: inputArray.join(stopSequence),
     max_new_tokens: 1, // wait for up to 1 tokens before displaying
-    stop_sequence: stopSequence,
+    stop_sequence: stopSequenceC,
     do_sample: 1, // enable top p and the like
     temperature,
-    top_p: topP,
-    extra_stop_sequences: stopSequencesC
-  }
+    top_p: topP
+  } as any
+  if (stopSequencesC.length) petalsRequest.extra_stop_sequences = stopSequencesC
   ws.send(JSON.stringify(petalsRequest))
   ws.onmessage = event => {
     // Remove updating indicator
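
The generate request now carries stopSequenceC as its single stop_sequence and only attaches extra_stop_sequences when there is something to send, which avoids posting an empty array. The payload shape, as far as it can be read off this code (an assumption from the diff, not an authoritative Petals API definition):

    interface PetalsGenerateRequest {
      type: 'generate'
      inputs: string                   // role-tagged messages joined on the stop sequence
      max_new_tokens: number
      stop_sequence: string            // primary stop sequence (stopSequenceC)
      do_sample: number                // 1 enables top_p / temperature sampling
      temperature: number
      top_p: number
      extra_stop_sequences?: string[]  // only set when stopSequencesC is non-empty
    }
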
@@ -106,17 +122,6 @@ export const runPetalsCompletionRequest = async (
           }]
         } as any
       )
-      if (response.stop) {
-        const message = chatResponse.getMessages()[0]
-        if (message) {
-          for (let i = 0, l = stopSequences.length; i < l; i++) {
-            if (message.content.endsWith(stopSequences[i])) {
-              message.content = message.content.slice(0, message.content.length - stopSequences[i].length)
-              updateMessages(chat.id)
-            }
-          }
-        }
-      }
     }, 1)
   }
 }

Changed file 2 of 2:

@@ -42,27 +42,49 @@ const modelDetails : Record<string, ModelDetail> = {
     completion: 0.000004, // $0.004 per 1000 tokens completion
     max: 16384 // 16k max token buffer
   },
+  'enoch/llama-65b-hf': {
+    type: 'Petals',
+    label: 'Petals - Llama-65b',
+    stop: ['###', '</s>'],
+    userStart: '<|user|>',
+    assistantStart: '<|[[CHARACTER_NAME]]|>',
+    systemStart: '',
+    prompt: 0.000000, // $0.000 per 1000 tokens prompt
+    completion: 0.000000, // $0.000 per 1000 tokens completion
+    max: 2048 // 2k max token buffer
+  },
+  'timdettmers/guanaco-65b': {
+    type: 'Petals',
+    label: 'Petals - Guanaco-65b',
+    stop: ['###', '</s>'],
+    userStart: '<|user|>',
+    assistantStart: '<|[[CHARACTER_NAME]]|>',
+    systemStart: '',
+    prompt: 0.000000, // $0.000 per 1000 tokens prompt
+    completion: 0.000000, // $0.000 per 1000 tokens completion
+    max: 2048 // 2k max token buffer
+  },
   'meta-llama/Llama-2-70b-chat-hf': {
     type: 'Petals',
     label: 'Petals - Llama-2-70b-chat',
-    stop: ['</s>'],
-    userStart: '[user]',
-    assistantStart: '[[[CHARACTER_NAME]]]',
+    stop: ['###', '</s>'],
+    userStart: '<|user|>',
+    assistantStart: '<|[[CHARACTER_NAME]]|>',
     systemStart: '',
     prompt: 0.000000, // $0.000 per 1000 tokens prompt
     completion: 0.000000, // $0.000 per 1000 tokens completion
     max: 4096 // 4k max token buffer
   },
-  'timdettmers/guanaco-65b': {
+  'meta-llama/Llama-2-70b-hf': {
     type: 'Petals',
-    label: 'Petals - guanaco-65b',
-    stop: ['</s>'],
-    userStart: '[user]',
-    assistantStart: '[[[CHARACTER_NAME]]]',
+    label: 'Petals - Llama-2-70b',
+    stop: ['###', '</s>'],
+    userStart: '<|user|>',
+    assistantStart: '<|[[CHARACTER_NAME]]|>',
     systemStart: '',
     prompt: 0.000000, // $0.000 per 1000 tokens prompt
     completion: 0.000000, // $0.000 per 1000 tokens completion
-    max: 2048 // 2k max token buffer
+    max: 4096 // 4k max token buffer
   }
 }
 
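
The new Petals entries all use '<|user|>' and '<|[[CHARACTER_NAME]]|>' as role markers, where [[CHARACTER_NAME]] looks like a placeholder that gets replaced with the profile's character name (presumably by mergeProfileFields elsewhere in the app). A guess at that substitution, for illustration only:

    // Hypothetical sketch; the real replacement lives in mergeProfileFields,
    // whose implementation is not part of this diff.
    const fillCharacterName = (start: string, characterName: string): string =>
      start.split('[[CHARACTER_NAME]]').join(characterName)

    // fillCharacterName('<|[[CHARACTER_NAME]]|>', 'Assistant') -> '<|Assistant|>'
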
@@ -107,8 +129,10 @@ export const supportedModels : Record<string, ModelDetail> = {
   'gpt-4-32k': modelDetails['gpt-4-32k'],
   'gpt-4-32k-0314': modelDetails['gpt-4-32k'],
   'gpt-4-32k-0613': modelDetails['gpt-4-32k'],
+  'enoch/llama-65b-hf': modelDetails['enoch/llama-65b-hf'],
+  'timdettmers/guanaco-65b': modelDetails['timdettmers/guanaco-65b'],
+  'meta-llama/Llama-2-70b-hf': modelDetails['meta-llama/Llama-2-70b-hf'],
   'meta-llama/Llama-2-70b-chat-hf': modelDetails['meta-llama/Llama-2-70b-chat-hf']
-  // 'timdettmers/guanaco-65b': modelDetails['timdettmers/guanaco-65b']
 }
 
 const lookupList = {
@@ -154,27 +178,27 @@ export const getEndpoint = (model: Model): string => {
 }
 
 export const getStopSequence = (chat: Chat): string => {
-  return valueOf(chat.id, getChatSettingObjectByKey('stopSequence').placeholder)
+  return chat.settings.stopSequence || valueOf(chat.id, getChatSettingObjectByKey('stopSequence').placeholder)
 }
 
 export const getUserStart = (chat: Chat): string => {
   return mergeProfileFields(
     chat.settings,
-    valueOf(chat.id, getChatSettingObjectByKey('userMessageStart').placeholder)
+    chat.settings.userMessageStart || valueOf(chat.id, getChatSettingObjectByKey('userMessageStart').placeholder)
   )
 }
 
 export const getAssistantStart = (chat: Chat): string => {
   return mergeProfileFields(
     chat.settings,
-    valueOf(chat.id, getChatSettingObjectByKey('assistantMessageStart').placeholder)
+    chat.settings.assistantMessageStart || valueOf(chat.id, getChatSettingObjectByKey('assistantMessageStart').placeholder)
   )
 }
 
 export const getSystemStart = (chat: Chat): string => {
   return mergeProfileFields(
     chat.settings,
-    valueOf(chat.id, getChatSettingObjectByKey('systemMessageStart').placeholder)
+    chat.settings.systemMessageStart || valueOf(chat.id, getChatSettingObjectByKey('systemMessageStart').placeholder)
   )
 }
 
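
These getters now let a value saved on the chat's own settings win over the profile placeholder returned by valueOf. A pared-down sketch of that fallback pattern (ChatLite and placeholderLookup are illustrative stand-ins for the app's Chat type and the valueOf/getChatSettingObjectByKey helpers):

    interface ChatLite { id: number, settings: { stopSequence?: string } }

    const stopSequenceFor = (chat: ChatLite, placeholderLookup: (id: number) => string): string =>
      chat.settings.stopSequence || placeholderLookup(chat.id)

    // stopSequenceFor({ id: 1, settings: { stopSequence: '<|user|>' } }, () => '###') -> '<|user|>'
    // stopSequenceFor({ id: 1, settings: {} }, () => '###')                           -> '###'
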