More changes for Petals integration

This commit is contained in:
Webifi 2023-07-22 13:24:18 -05:00
parent df222e7028
commit 9a6004c55d
9 changed files with 271 additions and 218 deletions

View File: ApiUtil.svelte

@@ -5,12 +5,12 @@
const endpointGenerations = import.meta.env.VITE_ENDPOINT_GENERATIONS || '/v1/images/generations'
const endpointModels = import.meta.env.VITE_ENDPOINT_MODELS || '/v1/models'
const endpointEmbeddings = import.meta.env.VITE_ENDPOINT_EMBEDDINGS || '/v1/embeddings'
const endpointPetalsV2Websocket = import.meta.env.VITE_PEDALS_WEBSOCKET || 'wss://chat.petals.dev/api/v2/generate'
const endpointPetals = import.meta.env.VITE_PEDALS_WEBSOCKET || 'wss://chat.petals.dev/api/v2/generate'
export const getApiBase = ():string => apiBase
export const getEndpointCompletions = ():string => endpointCompletions
export const getEndpointGenerations = ():string => endpointGenerations
export const getEndpointModels = ():string => endpointModels
export const getEndpointEmbeddings = ():string => endpointEmbeddings
export const getPetalsV2Websocket = ():string => endpointPetalsV2Websocket
export const getPetals = ():string => endpointPetals
</script>
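A side note on configuration: the renamed endpointPetals constant still reads the same VITE_PEDALS_WEBSOCKET build-time variable, so a self-hosted Petals server can be wired in without code changes. A minimal sketch of such an override, with a placeholder host (only the variable name comes from the code above):

# .env.local (hypothetical host; must point at your own Petals API server)
VITE_PEDALS_WEBSOCKET=wss://petals.example.com/api/v2/generate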

View File: ChatCompletionResponse.svelte

@@ -65,6 +65,10 @@ export class ChatCompletionResponse {
this.promptTokenCount = tokens
}
getPromptTokenCount (): number {
return this.promptTokenCount
}
async updateImageFromSyncResponse (response: ResponseImage, prompt: string, model: Model) {
this.setModel(model)
for (let i = 0; i < response.data.length; i++) {

View File: ChatRequest.svelte

@@ -6,10 +6,11 @@
import { deleteMessage, getChatSettingValueNullDefault, insertMessages, getApiKey, addError, currentChatMessages, getMessages, updateMessages, deleteSummaryMessage } from './Storage.svelte'
import { scrollToBottom, scrollToMessage } from './Util.svelte'
import { getRequestSettingList, defaultModel } from './Settings.svelte'
import { EventStreamContentType, fetchEventSource } from '@microsoft/fetch-event-source'
import { v4 as uuidv4 } from 'uuid'
import { get } from 'svelte/store'
import { getEndpoint, getModelDetail, getRoleTag } from './Models.svelte'
import { getEndpoint, getModelDetail } from './Models.svelte'
import { runOpenAiCompletionRequest } from './ChatRequestOpenAi.svelte'
import { runPetalsCompletionRequest } from './ChatRequestPetals.svelte'
export class ChatRequest {
constructor () {
@@ -27,6 +28,14 @@ export class ChatRequest {
this.chat = chat
}
getChat (): Chat {
return this.chat
}
getChatSettings (): ChatSettings {
return this.chat.settings
}
// Common error handler
async handleError (response) {
let errorResponse
@@ -258,193 +267,10 @@ export class ChatRequest {
_this.controller = new AbortController()
const signal = _this.controller.signal
if (modelDetail.type === 'PetalsV2Websocket') {
// Petals
const ws = new WebSocket(getEndpoint(model))
const abortListener = (e:Event) => {
_this.updating = false
_this.updatingMessage = ''
chatResponse.updateFromError('User aborted request.')
signal.removeEventListener('abort', abortListener)
ws.close()
}
signal.addEventListener('abort', abortListener)
const stopSequences = modelDetail.stop || ['###']
const stopSequencesC = stopSequences.slice()
const stopSequence = stopSequencesC.shift()
chatResponse.onFinish(() => {
_this.updating = false
_this.updatingMessage = ''
})
ws.onopen = () => {
ws.send(JSON.stringify({
type: 'open_inference_session',
model,
max_length: maxTokens || opts.maxTokens
}))
ws.onmessage = event => {
const response = JSON.parse(event.data)
if (!response.ok) {
const err = new Error('Error opening socket: ' + response.traceback)
console.error(err)
throw err
}
const rMessages = request.messages || [] as Message[]
const inputArray = (rMessages).reduce((a, m) => {
const c = getRoleTag(m.role, model, chatSettings) + m.content
a.push(c)
return a
}, [] as string[])
const lastMessage = rMessages[rMessages.length - 1]
if (lastMessage && lastMessage.role !== 'assistant') {
inputArray.push(getRoleTag('assistant', model, chatSettings))
}
const petalsRequest = {
type: 'generate',
inputs: (request.messages || [] as Message[]).reduce((a, m) => {
const c = getRoleTag(m.role, model, chatSettings) + m.content
a.push(c)
return a
}, [] as string[]).join(stopSequence),
max_new_tokens: 3, // wait for up to 3 tokens before displaying
stop_sequence: stopSequence,
doSample: 1,
temperature: request.temperature || 0,
top_p: request.top_p || 0,
extra_stop_sequences: stopSequencesC
}
ws.send(JSON.stringify(petalsRequest))
ws.onmessage = event => {
// Remove updating indicator
_this.updating = 1 // hide indicator, but still signal we're updating
_this.updatingMessage = ''
const response = JSON.parse(event.data)
if (!response.ok) {
const err = new Error('Error in response: ' + response.traceback)
console.error(err)
throw err
}
window.setTimeout(() => {
chatResponse.updateFromAsyncResponse(
{
model,
choices: [{
delta: {
content: response.outputs,
role: 'assistant'
},
finish_reason: (response.stop ? 'stop' : null)
}]
} as any
)
if (response.stop) {
const message = chatResponse.getMessages()[0]
if (message) {
for (let i = 0, l = stopSequences.length; i < l; i++) {
if (message.content.endsWith(stopSequences[i])) {
message.content = message.content.slice(0, message.content.length - stopSequences[i].length)
updateMessages(chatId)
}
}
}
}
}, 1)
}
}
ws.onclose = () => {
_this.updating = false
_this.updatingMessage = ''
chatResponse.updateFromClose()
}
ws.onerror = err => {
console.error(err)
throw err
}
}
if (modelDetail.type === 'Petals') {
await runPetalsCompletionRequest(request, _this as any, chatResponse as any, signal, opts)
} else {
// OpenAI
const abortListener = (e:Event) => {
_this.updating = false
_this.updatingMessage = ''
chatResponse.updateFromError('User aborted request.')
signal.removeEventListener('abort', abortListener)
}
signal.addEventListener('abort', abortListener)
const fetchOptions = {
method: 'POST',
headers: {
Authorization: `Bearer ${getApiKey()}`,
'Content-Type': 'application/json'
},
body: JSON.stringify(request),
signal
}
if (opts.streaming) {
/**
* Streaming request/response
* We'll get the response a token at a time, as soon as they are ready
*/
chatResponse.onFinish(() => {
_this.updating = false
_this.updatingMessage = ''
})
fetchEventSource(getEndpoint(model), {
...fetchOptions,
openWhenHidden: true,
onmessage (ev) {
// Remove updating indicator
_this.updating = 1 // hide indicator, but still signal we're updating
_this.updatingMessage = ''
// console.log('ev.data', ev.data)
if (!chatResponse.hasFinished()) {
if (ev.data === '[DONE]') {
// ?? anything to do when "[DONE]"?
} else {
const data = JSON.parse(ev.data)
// console.log('data', data)
window.setTimeout(() => { chatResponse.updateFromAsyncResponse(data) }, 1)
}
}
},
onclose () {
_this.updating = false
_this.updatingMessage = ''
chatResponse.updateFromClose()
},
onerror (err) {
console.error(err)
throw err
},
async onopen (response) {
if (response.ok && response.headers.get('content-type') === EventStreamContentType) {
// everything's good
} else {
// client-side errors are usually non-retriable:
await _this.handleError(response)
}
}
}).catch(err => {
_this.updating = false
_this.updatingMessage = ''
chatResponse.updateFromError(err.message)
})
} else {
/**
* Non-streaming request/response
* We'll get the response all at once, after a long delay
*/
const response = await fetch(getEndpoint(model), fetchOptions)
if (!response.ok) {
await _this.handleError(response)
} else {
const json = await response.json()
// Remove updating indicator
_this.updating = false
_this.updatingMessage = ''
chatResponse.updateFromSyncResponse(json)
}
}
await runOpenAiCompletionRequest(request, _this as any, chatResponse as any, signal, opts)
}
} catch (e) {
// console.error(e)
@@ -456,7 +282,7 @@ export class ChatRequest {
return chatResponse
}
private getModel (): Model {
getModel (): Model {
return this.chat.settings.model || defaultModel
}

View File: ChatRequestOpenAi.svelte

@@ -0,0 +1,100 @@
<script context="module" lang="ts">
import { EventStreamContentType, fetchEventSource } from '@microsoft/fetch-event-source'
import ChatCompletionResponse from './ChatCompletionResponse.svelte'
import ChatRequest from './ChatRequest.svelte'
import { getEndpoint } from './Models.svelte'
import { getApiKey } from './Storage.svelte'
import type { ChatCompletionOpts, Request } from './Types.svelte'
export const runOpenAiCompletionRequest = async (
request: Request,
chatRequest: ChatRequest,
chatResponse: ChatCompletionResponse,
signal: AbortSignal,
opts: ChatCompletionOpts) => {
// OpenAI Request
const model = chatRequest.getModel()
const abortListener = (e:Event) => {
chatRequest.updating = false
chatRequest.updatingMessage = ''
chatResponse.updateFromError('User aborted request.')
signal.removeEventListener('abort', abortListener)
}
signal.addEventListener('abort', abortListener)
const fetchOptions = {
method: 'POST',
headers: {
Authorization: `Bearer ${getApiKey()}`,
'Content-Type': 'application/json'
},
body: JSON.stringify(request),
signal
}
if (opts.streaming) {
/**
* Streaming request/response
* We'll get the response a token at a time, as soon as they are ready
*/
chatResponse.onFinish(() => {
chatRequest.updating = false
chatRequest.updatingMessage = ''
})
fetchEventSource(getEndpoint(model), {
...fetchOptions,
openWhenHidden: true,
onmessage (ev) {
// Remove updating indicator
chatRequest.updating = 1 // hide indicator, but still signal we're updating
chatRequest.updatingMessage = ''
// console.log('ev.data', ev.data)
if (!chatResponse.hasFinished()) {
if (ev.data === '[DONE]') {
// ?? anything to do when "[DONE]"?
} else {
const data = JSON.parse(ev.data)
// console.log('data', data)
window.setTimeout(() => { chatResponse.updateFromAsyncResponse(data) }, 1)
}
}
},
onclose () {
chatRequest.updating = false
chatRequest.updatingMessage = ''
chatResponse.updateFromClose()
},
onerror (err) {
console.error(err)
throw err
},
async onopen (response) {
if (response.ok && response.headers.get('content-type') === EventStreamContentType) {
// everything's good
} else {
// client-side errors are usually non-retriable:
await chatRequest.handleError(response)
}
}
}).catch(err => {
chatRequest.updating = false
chatRequest.updatingMessage = ''
chatResponse.updateFromError(err.message)
})
} else {
/**
* Non-streaming request/response
* We'll get the response all at once, after a long delay
*/
const response = await fetch(getEndpoint(model), fetchOptions)
if (!response.ok) {
await chatRequest.handleError(response)
} else {
const json = await response.json()
// Remove updating indicator
chatRequest.updating = false
chatRequest.updatingMessage = ''
chatResponse.updateFromSyncResponse(json)
}
}
}
</script>
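For reference, the streaming branch above consumes OpenAI-style server-sent events: each data: payload is a JSON chunk whose choices[].delta carries the next piece of content, and the literal [DONE] sentinel marks the end of the stream. Roughly (values illustrative):

data: {"choices":[{"delta":{"content":" Hello","role":"assistant"},"finish_reason":null}]}
data: [DONE]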

View File: ChatRequestPetals.svelte

@@ -0,0 +1,126 @@
<script context="module" lang="ts">
import ChatCompletionResponse from './ChatCompletionResponse.svelte'
import ChatRequest from './ChatRequest.svelte'
import { getEndpoint, getModelDetail, getRoleTag } from './Models.svelte'
import type { ChatCompletionOpts, Message, Request } from './Types.svelte'
import { getModelMaxTokens } from './Stats.svelte'
import { updateMessages } from './Storage.svelte'
export const runPetalsCompletionRequest = async (
request: Request,
chatRequest: ChatRequest,
chatResponse: ChatCompletionResponse,
signal: AbortSignal,
opts: ChatCompletionOpts) => {
// Petals
const model = chatRequest.getModel()
const modelDetail = getModelDetail(model)
const ws = new WebSocket(getEndpoint(model))
const abortListener = (e:Event) => {
chatRequest.updating = false
chatRequest.updatingMessage = ''
chatResponse.updateFromError('User aborted request.')
signal.removeEventListener('abort', abortListener)
ws.close()
}
signal.addEventListener('abort', abortListener)
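// Prompt framing sequences for this model: the first stop sequence doubles as the separator when
// joining messages into the prompt, the remaining ones are forwarded as extra_stop_sequences,
// and the start sequence is only used later to trim the leading tag from the final output.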
const startSequences = modelDetail.start || []
const startSequence = startSequences[0] || ''
const stopSequences = modelDetail.stop || ['###']
const stopSequencesC = stopSequences.slice()
const stopSequence = stopSequencesC.shift()
const maxTokens = getModelMaxTokens(model)
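// Size the inference session so it can hold both the prompt and the requested generation length;
// if the prompt alone exceeds the requested budget, grow it by the prompt size, capped at the
// model's context window.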
let maxLen = Math.min(opts.maxTokens || chatRequest.chat.max_tokens || maxTokens, maxTokens)
const promptTokenCount = chatResponse.getPromptTokenCount()
if (promptTokenCount > maxLen) {
maxLen = Math.min(maxLen + promptTokenCount, maxTokens)
}
chatResponse.onFinish(() => {
chatRequest.updating = false
chatRequest.updatingMessage = ''
})
ws.onopen = () => {
ws.send(JSON.stringify({
type: 'open_inference_session',
model,
max_length: maxLen
}))
ws.onmessage = event => {
const response = JSON.parse(event.data)
if (!response.ok) {
const err = new Error('Error opening socket: ' + response.traceback)
console.error(err)
throw err
}
const rMessages = request.messages || [] as Message[]
const inputArray = (rMessages).reduce((a, m) => {
const c = getRoleTag(m.role, model, chatRequest.getChatSettings()) + m.content
a.push(c)
return a
}, [] as string[])
const lastMessage = rMessages[rMessages.length - 1]
if (lastMessage && lastMessage.role !== 'assistant') {
inputArray.push(getRoleTag('assistant', model, chatRequest.getChatSettings()))
}
const petalsRequest = {
type: 'generate',
inputs: inputArray.join(stopSequence),
max_new_tokens: 3, // wait for up to 3 tokens before displaying
stop_sequence: stopSequence,
doSample: 1,
temperature: request.temperature || 0,
top_p: request.top_p || 0,
extra_stop_sequences: stopSequencesC
}
ws.send(JSON.stringify(petalsRequest))
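// After the generate request, every further message on this socket carries an incremental chunk
// of output in response.outputs; response.stop marks the final chunk.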
ws.onmessage = event => {
// Remove updating indicator
chatRequest.updating = 1 // hide indicator, but still signal we're updating
chatRequest.updatingMessage = ''
const response = JSON.parse(event.data)
if (!response.ok) {
const err = new Error('Error in response: ' + response.traceback)
console.error(err)
throw err
}
window.setTimeout(() => {
chatResponse.updateFromAsyncResponse(
{
model,
choices: [{
delta: {
content: response.outputs,
role: 'assistant'
},
finish_reason: (response.stop ? 'stop' : null)
}]
} as any
)
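// On the final chunk, strip a trailing stop sequence (and any leading start sequence) from the
// stored message so the raw markers never show up in the chat.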
if (response.stop) {
const message = chatResponse.getMessages()[0]
if (message) {
for (let i = 0, l = stopSequences.length; i < l; i++) {
if (message.content.endsWith(stopSequences[i])) {
message.content = message.content.slice(0, message.content.length - stopSequences[i].length)
const startS = startSequences[i] || ''
if (message.content.startsWith(startS)) message.content = message.content.slice(startS.length)
updateMessages(chatRequest.getChat().id)
}
}
}
}
}, 1)
}
}
ws.onclose = () => {
chatRequest.updating = false
chatRequest.updatingMessage = ''
chatResponse.updateFromClose()
}
ws.onerror = err => {
console.error(err)
throw err
}
}
}
</script>
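Taken together, the handlers above imply the following exchange with the Petals websocket endpoint (the field names are the ones read and written by the code; the concrete values are illustrative only):

// client -> server: open a session sized to prompt plus generation budget
{ "type": "open_inference_session", "model": "meta-llama/Llama-2-70b-chat-hf", "max_length": 1024 }
// server -> client: acknowledgement (ok: false plus a traceback on failure)
{ "ok": true }
// client -> server: generate from the role-tagged messages joined by the primary stop sequence
{ "type": "generate", "inputs": "[user] Hello</s>[Assistant] ", "max_new_tokens": 3, "stop_sequence": "</s>", "doSample": 1, "temperature": 0, "top_p": 0, "extra_stop_sequences": [] }
// server -> client: repeated token chunks until stop is true
{ "ok": true, "outputs": " Hi", "stop": false }
{ "ok": true, "outputs": " there.</s>", "stop": true }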

View File

@@ -3,7 +3,7 @@
import Footer from './Footer.svelte'
import { replace } from 'svelte-spa-router'
import { onMount } from 'svelte'
import { getPetalsV2Websocket } from './ApiUtil.svelte'
import { getPetals } from './ApiUtil.svelte'
$: apiKey = $apiKeyStorage
@@ -112,7 +112,7 @@ const setPetalsEnabled = (event: Event) => {
aria-label="PetalsAPI Endpoint"
type="text"
class="input"
placeholder={getPetalsV2Websocket()}
placeholder={getPetals()}
value={$globalStorage.pedalsEndpoint || ''}
/>
</p>
@@ -123,7 +123,7 @@ const setPetalsEnabled = (event: Event) => {
</form>
<p>
Only use <u>{getPetalsV2Websocket()}</u> for testing. You must set up your own Petals server for actual use.
Only use <u>{getPetals()}</u> for testing. You must set up your own Petals server for actual use.
</p>
<p>
<b>Do not send sensitive information when using Petals.</b>

View File: Models.svelte

@@ -1,5 +1,5 @@
<script context="module" lang="ts">
import { getApiBase, getEndpointCompletions, getEndpointGenerations, getEndpointModels, getPetalsV2Websocket } from './ApiUtil.svelte'
import { getApiBase, getEndpointCompletions, getEndpointGenerations, getEndpointModels, getPetals } from './ApiUtil.svelte'
import { apiKeyStorage, globalStorage } from './Storage.svelte'
import { get } from 'svelte/store'
import type { ModelDetail, Model, ResponseModels, SelectOption, ChatSettings } from './Types.svelte'
@@ -34,9 +34,10 @@ const modelDetails : Record<string, ModelDetail> = {
max: 16384 // 16k max token buffer
},
'meta-llama/Llama-2-70b-chat-hf': {
type: 'PetalsV2Websocket',
type: 'Petals',
label: 'Petals - Llama-2-70b-chat',
stop: ['###', '</s>'],
start: [''],
stop: ['</s>'],
prompt: 0.000000, // $0.000 per 1000 tokens prompt
completion: 0.000000, // $0.000 per 1000 tokens completion
max: 4096 // 4k max token buffer
@@ -119,8 +120,8 @@ export const getEndpoint = (model: Model): string => {
const modelDetails = getModelDetail(model)
const gSettings = get(globalStorage)
switch (modelDetails.type) {
case 'PetalsV2Websocket':
return gSettings.pedalsEndpoint || getPetalsV2Websocket()
case 'Petals':
return gSettings.pedalsEndpoint || getPetals()
case 'OpenAIDall-e':
return getApiBase() + getEndpointGenerations()
case 'OpenAIChat':
@@ -132,12 +133,12 @@ export const getEndpoint = (model: Model): string => {
export const getRoleTag = (role: string, model: Model, settings: ChatSettings): string => {
const modelDetails = getModelDetail(model)
switch (modelDetails.type) {
case 'PetalsV2Websocket':
case 'Petals':
if (role === 'assistant') {
return ('Assistant') +
': '
if (settings.useSystemPrompt && settings.characterName) return '[' + settings.characterName + '] '
return '[Assistant] '
}
if (role === 'user') return 'Human: '
if (role === 'user') return '[user] '
return ''
case 'OpenAIDall-e':
return role
@@ -150,7 +151,7 @@ export const getRoleTag = (role: string, model: Model, settings: ChatSettings):
export const getTokens = (model: Model, value: string): number[] => {
const modelDetails = getModelDetail(model)
switch (modelDetails.type) {
case 'PetalsV2Websocket':
case 'Petals':
return llamaTokenizer.encode(value)
case 'OpenAIDall-e':
return [0]
@@ -184,7 +185,7 @@ export async function getModelOptions (): Promise<SelectOption[]> {
}
const filteredModels = supportedModelKeys.filter((model) => {
switch (getModelDetail(model).type) {
case 'PetalsV2Websocket':
case 'Petals':
return gSettings.enablePetals
case 'OpenAIChat':
default:

View File: Stats.svelte

@@ -10,19 +10,12 @@
export const countPromptTokens = (prompts:Message[], model:Model, settings: ChatSettings):number => {
const detail = getModelDetail(model)
const count = prompts.reduce((a, m) => {
switch (detail.type) {
case 'PetalsV2Websocket':
a += countMessageTokens(m, model, settings)
break
case 'OpenAIChat':
default:
a += countMessageTokens(m, model, settings)
}
a += countMessageTokens(m, model, settings)
return a
}, 0)
switch (detail.type) {
case 'PetalsV2Websocket':
return count + (Math.max(prompts.length - 1, 0) * countTokens(model, (detail.stop && detail.stop[0]) || '###')) // todo, make stop per model?
case 'Petals':
return count
case 'OpenAIChat':
default:
// Not sure how OpenAI formats it, but this seems to get close to the right counts.
@@ -34,9 +27,11 @@
export const countMessageTokens = (message:Message, model:Model, settings: ChatSettings):number => {
const detail = getModelDetail(model)
const start = detail.start && detail.start[0]
const stop = detail.stop && detail.stop[0]
switch (detail.type) {
case 'PetalsV2Websocket':
return countTokens(model, getRoleTag(message.role, model, settings) + ': ' + message.content)
case 'Petals':
return countTokens(model, (start || '') + getRoleTag(message.role, model, settings) + ': ' + message.content + (stop || '###'))
case 'OpenAIChat':
default:
// Not sure how OpenAI formats it, but this seems to get close to the right counts.

View File: Types.svelte

@@ -7,12 +7,13 @@ export type Model = typeof supportedModelKeys[number];
export type ImageGenerationSizes = typeof imageGenerationSizeTypes[number];
export type RequestType = 'OpenAIChat' | 'OpenAIDall-e' | 'PetalsV2Websocket'
export type RequestType = 'OpenAIChat' | 'OpenAIDall-e' | 'Petals'
export type ModelDetail = {
type: RequestType;
label?: string;
stop?: string[];
start?: string[];
prompt: number;
completion: number;
max: number;