Initial test of Petals as alternative to OpenAI
parent 8e35b198da
commit 914055f1f9
@@ -27,6 +27,7 @@
"eslint-plugin-svelte3": "^4.0.0",
"flourite": "^1.2.4",
"gpt-tokenizer": "^2.0.0",
"llama-tokenizer-js": "^1.1.1",
"postcss": "^8.4.26",
"sass": "^1.63.6",
"stacking-order": "^2.0.0",
@@ -3182,6 +3183,12 @@
"node": ">= 0.8.0"
}
},
"node_modules/llama-tokenizer-js": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/llama-tokenizer-js/-/llama-tokenizer-js-1.1.1.tgz",
"integrity": "sha512-5H2oSJnSufWGhOw6hcCGAqJeB3POmeIBzRklH3cXs0L4MSAYdwoYTodni4j5YVo6jApdhaqaNVU66gNRgXeBRg==",
"dev": true
},
"node_modules/locate-path": {
"version": "6.0.0",
"resolved": "https://registry.npmjs.org/locate-path/-/locate-path-6.0.0.tgz",

@@ -33,6 +33,7 @@
"eslint-plugin-svelte3": "^4.0.0",
"flourite": "^1.2.4",
"gpt-tokenizer": "^2.0.0",
"llama-tokenizer-js": "^1.1.1",
"postcss": "^8.4.26",
"sass": "^1.63.6",
"stacking-order": "^2.0.0",

@@ -5,10 +5,12 @@
const endpointGenerations = import.meta.env.VITE_ENDPOINT_GENERATIONS || '/v1/images/generations'
const endpointModels = import.meta.env.VITE_ENDPOINT_MODELS || '/v1/models'
const endpointEmbeddings = import.meta.env.VITE_ENDPOINT_EMBEDDINGS || '/v1/embeddings'
const endpointPetalsV2Websocket = import.meta.env.VITE_PEDALS_WEBSOCKET || 'wss://chat.petals.dev/api/v2/generate'

export const getApiBase = ():string => apiBase
export const getEndpointCompletions = ():string => endpointCompletions
export const getEndpointGenerations = ():string => endpointGenerations
export const getEndpointModels = ():string => endpointModels
export const getEndpointEmbeddings = ():string => endpointEmbeddings
export const getPetalsV2Websocket = ():string => endpointPetalsV2Websocket
</script>

@@ -1,9 +1,9 @@
<script context="module" lang="ts">
import { setImage } from './ImageStore.svelte'
import { countTokens } from './Models.svelte'
// TODO: Integrate API calls
import { addMessage, getLatestKnownModel, saveChatStore, setLatestKnownModel, subtractRunningTotal, updateRunningTotal } from './Storage.svelte'
import { addMessage, getLatestKnownModel, setLatestKnownModel, subtractRunningTotal, updateMessages, updateRunningTotal } from './Storage.svelte'
import type { Chat, ChatCompletionOpts, ChatImage, Message, Model, Response, ResponseImage, Usage } from './Types.svelte'
import { encode } from 'gpt-tokenizer'
import { v4 as uuidv4 } from 'uuid'

export class ChatCompletionResponse {
@@ -138,10 +138,10 @@ export class ChatCompletionResponse {
message.content = this.initialFillMerge(message.content, choice.delta?.content)
message.content += choice.delta.content
}
completionTokenCount += encode(message.content).length
completionTokenCount += countTokens(this.model, message.content)
message.model = response.model
message.finish_reason = choice.finish_reason
message.streaming = choice.finish_reason === null && !this.finished
message.streaming = !choice.finish_reason && !this.finished
this.messages[i] = message
})
// total up the tokens
@@ -209,10 +209,10 @@ export class ChatCompletionResponse {
}

private finish = (): void => {
this.messages.forEach(m => { m.streaming = false }) // make sure all are marked stopped
updateMessages(this.chat.id)
if (this.finished) return
this.finished = true
this.messages.forEach(m => { m.streaming = false }) // make sure all are marked stopped
saveChatStore()
const message = this.messages[0]
const model = this.model || getLatestKnownModel(this.chat.settings.model)
if (message) {

@@ -7,9 +7,9 @@
import { scrollToBottom, scrollToMessage } from './Util.svelte'
import { getRequestSettingList, defaultModel } from './Settings.svelte'
import { EventStreamContentType, fetchEventSource } from '@microsoft/fetch-event-source'
import { getApiBase, getEndpointCompletions, getEndpointGenerations } from './ApiUtil.svelte'
import { v4 as uuidv4 } from 'uuid'
import { get } from 'svelte/store'
import { getEndpoint, getModelDetail, getRoleTag } from './Models.svelte'

export class ChatRequest {
constructor () {
@@ -77,7 +77,7 @@ export class ChatRequest {
const chatResponse = new ChatCompletionResponse(opts)

try {
const response = await fetch(getApiBase() + getEndpointGenerations(), fetchOptions)
const response = await fetch(getEndpoint('dall-e-' + size), fetchOptions)
if (!response.ok) {
await _this.handleError(response)
} else {
@@ -206,7 +206,7 @@ export class ChatRequest {
}

// Get token counts
const promptTokenCount = countPromptTokens(messagePayload, model)
const promptTokenCount = countPromptTokens(messagePayload, model, chatSettings)
const maxAllowed = maxTokens - (promptTokenCount + 1)

// Build the API request body
@@ -245,6 +245,9 @@ export class ChatRequest {

// Set-up and make the request
const chatResponse = new ChatCompletionResponse(opts)

const modelDetail = getModelDetail(model)

try {
// Add out token count to the response handler
// (streaming doesn't return counts, so we need to do it client side)
@@ -254,87 +257,193 @@
// so we deal with it ourselves
_this.controller = new AbortController()
const signal = _this.controller.signal
const abortListener = (e:Event) => {
_this.updating = false
_this.updatingMessage = ''
chatResponse.updateFromError('User aborted request.')
signal.removeEventListener('abort', abortListener)
}
signal.addEventListener('abort', abortListener)

const fetchOptions = {
method: 'POST',
headers: {
Authorization: `Bearer ${getApiKey()}`,
'Content-Type': 'application/json'
},
body: JSON.stringify(request),
signal
}

if (opts.streaming) {
/**
* Streaming request/response
* We'll get the response a token at a time, as soon as they are ready
*/
if (modelDetail.type === 'PetalsV2Websocket') {
// Petals
const ws = new WebSocket(getEndpoint(model))
const abortListener = (e:Event) => {
_this.updating = false
_this.updatingMessage = ''
chatResponse.updateFromError('User aborted request.')
signal.removeEventListener('abort', abortListener)
ws.close()
}
signal.addEventListener('abort', abortListener)
const stopSequences = modelDetail.stop || ['###']
const stopSequencesC = stopSequences.slice()
const stopSequence = stopSequencesC.shift()
chatResponse.onFinish(() => {
_this.updating = false
_this.updatingMessage = ''
})
fetchEventSource(getApiBase() + getEndpointCompletions(), {
...fetchOptions,
openWhenHidden: true,
onmessage (ev) {
// Remove updating indicator
_this.updating = 1 // hide indicator, but still signal we're updating
_this.updatingMessage = ''
// console.log('ev.data', ev.data)
if (!chatResponse.hasFinished()) {
if (ev.data === '[DONE]') {
// ?? anything to do when "[DONE]"?
} else {
const data = JSON.parse(ev.data)
// console.log('data', data)
window.setTimeout(() => { chatResponse.updateFromAsyncResponse(data) }, 1)
}
ws.onopen = () => {
ws.send(JSON.stringify({
type: 'open_inference_session',
model,
max_length: maxTokens || opts.maxTokens
}))
ws.onmessage = event => {
const response = JSON.parse(event.data)
if (!response.ok) {
const err = new Error('Error opening socket: ' + response.traceback)
console.error(err)
throw err
}
},
onclose () {
const rMessages = request.messages || [] as Message[]
const inputArray = (rMessages).reduce((a, m) => {
const c = getRoleTag(m.role, model, chatSettings) + m.content
a.push(c)
return a
}, [] as string[])
const lastMessage = rMessages[rMessages.length - 1]
if (lastMessage && lastMessage.role !== 'assistant') {
inputArray.push(getRoleTag('assistant', model, chatSettings))
}
const petalsRequest = {
type: 'generate',
inputs: (request.messages || [] as Message[]).reduce((a, m) => {
const c = getRoleTag(m.role, model, chatSettings) + m.content
a.push(c)
return a
}, [] as string[]).join(stopSequence),
max_new_tokens: 3, // wait for up to 3 tokens before displaying
stop_sequence: stopSequence,
doSample: 1,
temperature: request.temperature || 0,
top_p: request.top_p || 0,
extra_stop_sequences: stopSequencesC
}
ws.send(JSON.stringify(petalsRequest))
ws.onmessage = event => {
// Remove updating indicator
_this.updating = 1 // hide indicator, but still signal we're updating
_this.updatingMessage = ''
const response = JSON.parse(event.data)
if (!response.ok) {
const err = new Error('Error in response: ' + response.traceback)
console.error(err)
throw err
}
window.setTimeout(() => {
chatResponse.updateFromAsyncResponse(
{
model,
choices: [{
delta: {
content: response.outputs,
role: 'assistant'
},
finish_reason: (response.stop ? 'stop' : null)
}]
} as any
)
if (response.stop) {
const message = chatResponse.getMessages()[0]
if (message) {
for (let i = 0, l = stopSequences.length; i < l; i++) {
if (message.content.endsWith(stopSequences[i])) {
message.content = message.content.slice(0, message.content.length - stopSequences[i].length)
updateMessages(chatId)
}
}
}
}
}, 1)
}
}
ws.onclose = () => {
_this.updating = false
_this.updatingMessage = ''
chatResponse.updateFromClose()
},
onerror (err) {
}
ws.onerror = err => {
console.error(err)
throw err
},
async onopen (response) {
if (response.ok && response.headers.get('content-type') === EventStreamContentType) {
// everything's good
} else {
// client-side errors are usually non-retriable:
await _this.handleError(response)
}
}
}).catch(err => {
}
} else {
// OpenAI
const abortListener = (e:Event) => {
_this.updating = false
_this.updatingMessage = ''
chatResponse.updateFromError(err.message)
})
} else {
chatResponse.updateFromError('User aborted request.')
signal.removeEventListener('abort', abortListener)
}
signal.addEventListener('abort', abortListener)
const fetchOptions = {
method: 'POST',
headers: {
Authorization: `Bearer ${getApiKey()}`,
'Content-Type': 'application/json'
},
body: JSON.stringify(request),
signal
}

if (opts.streaming) {
/**
* Streaming request/response
* We'll get the response a token at a time, as soon as they are ready
*/
chatResponse.onFinish(() => {
_this.updating = false
_this.updatingMessage = ''
})
fetchEventSource(getEndpoint(model), {
...fetchOptions,
openWhenHidden: true,
onmessage (ev) {
// Remove updating indicator
_this.updating = 1 // hide indicator, but still signal we're updating
_this.updatingMessage = ''
// console.log('ev.data', ev.data)
if (!chatResponse.hasFinished()) {
if (ev.data === '[DONE]') {
// ?? anything to do when "[DONE]"?
} else {
const data = JSON.parse(ev.data)
// console.log('data', data)
window.setTimeout(() => { chatResponse.updateFromAsyncResponse(data) }, 1)
}
}
},
onclose () {
_this.updating = false
_this.updatingMessage = ''
chatResponse.updateFromClose()
},
onerror (err) {
console.error(err)
throw err
},
async onopen (response) {
if (response.ok && response.headers.get('content-type') === EventStreamContentType) {
// everything's good
} else {
// client-side errors are usually non-retriable:
await _this.handleError(response)
}
}
}).catch(err => {
_this.updating = false
_this.updatingMessage = ''
chatResponse.updateFromError(err.message)
})
} else {
/**
* Non-streaming request/response
* We'll get the response all at once, after a long delay
*/
const response = await fetch(getApiBase() + getEndpointCompletions(), fetchOptions)
if (!response.ok) {
await _this.handleError(response)
} else {
const json = await response.json()
// Remove updating indicator
_this.updating = false
_this.updatingMessage = ''
chatResponse.updateFromSyncResponse(json)
const response = await fetch(getEndpoint(model), fetchOptions)
if (!response.ok) {
await _this.handleError(response)
} else {
const json = await response.json()
// Remove updating indicator
_this.updating = false
_this.updatingMessage = ''
chatResponse.updateFromSyncResponse(json)
}
}
}
} catch (e) {
@@ -393,11 +502,11 @@ export class ChatRequest {
* Gets an estimate of how many extra tokens will be added that won't be part of the visible messages
* @param filtered
*/
private getTokenCountPadding (filtered: Message[]): number {
private getTokenCountPadding (filtered: Message[], settings: ChatSettings): number {
let result = 0
// add cost of hiddenPromptPrefix
result += this.buildHiddenPromptPrefixMessages(filtered)
.reduce((a, m) => a + countMessageTokens(m, this.getModel()), 0)
.reduce((a, m) => a + countMessageTokens(m, this.getModel(), settings), 0)
// more here eventually?
return result
}
@@ -419,10 +528,10 @@ export class ChatRequest {
}

// Get extra counts for when the prompts are finally sent.
const countPadding = this.getTokenCountPadding(filtered)
const countPadding = this.getTokenCountPadding(filtered, chatSettings)

// See if we have enough to apply any of the reduction modes
const fullPromptSize = countPromptTokens(filtered, model) + countPadding
const fullPromptSize = countPromptTokens(filtered, model, chatSettings) + countPadding
if (fullPromptSize < chatSettings.summaryThreshold) return await continueRequest() // nothing to do yet
const overMax = fullPromptSize > maxTokens * 0.95

@@ -445,12 +554,12 @@ export class ChatRequest {
* *************************************************************
*/

let promptSize = countPromptTokens(top.concat(rw), model) + countPadding
let promptSize = countPromptTokens(top.concat(rw), model, chatSettings) + countPadding
while (rw.length && rw.length > pinBottom && promptSize >= chatSettings.summaryThreshold) {
const rolled = rw.shift()
// Hide messages we're "rolling"
if (rolled) rolled.suppress = true
promptSize = countPromptTokens(top.concat(rw), model) + countPadding
promptSize = countPromptTokens(top.concat(rw), model, chatSettings) + countPadding
}
// Run a new request, now with the rolled messages hidden
return await _this.sendRequest(get(currentChatMessages), {
@@ -466,26 +575,26 @@ export class ChatRequest {
const bottom = rw.slice(0 - pinBottom)
let continueCounter = chatSettings.summaryExtend + 1
rw = rw.slice(0, 0 - pinBottom)
let reductionPoolSize = countPromptTokens(rw, model)
let reductionPoolSize = countPromptTokens(rw, model, chatSettings)
const ss = Math.abs(chatSettings.summarySize)
const getSS = ():number => (ss < 1 && ss > 0)
? Math.round(reductionPoolSize * ss) // If summarySize between 0 and 1, use percentage of reduced
: Math.min(ss, reductionPoolSize * 0.5) // If > 1, use token count
const topSize = countPromptTokens(top, model)
const topSize = countPromptTokens(top, model, chatSettings)
let maxSummaryTokens = getSS()
let promptSummary = prepareSummaryPrompt(chatId, maxSummaryTokens)
const summaryRequest = { role: 'user', content: promptSummary } as Message
let promptSummarySize = countMessageTokens(summaryRequest, model)
let promptSummarySize = countMessageTokens(summaryRequest, model, chatSettings)
// Make sure there is enough room to generate the summary, and try to make sure
// the last prompt is a user prompt as that seems to work better for summaries
while ((topSize + reductionPoolSize + promptSummarySize + maxSummaryTokens) >= maxTokens ||
(reductionPoolSize >= 100 && rw[rw.length - 1]?.role !== 'user')) {
bottom.unshift(rw.pop() as Message)
reductionPoolSize = countPromptTokens(rw, model)
reductionPoolSize = countPromptTokens(rw, model, chatSettings)
maxSummaryTokens = getSS()
promptSummary = prepareSummaryPrompt(chatId, maxSummaryTokens)
summaryRequest.content = promptSummary
promptSummarySize = countMessageTokens(summaryRequest, model)
promptSummarySize = countMessageTokens(summaryRequest, model, chatSettings)
}
if (reductionPoolSize < 50) {
if (overMax) addError(chatId, 'Check summary settings. Unable to summarize enough messages.')
@@ -571,10 +680,10 @@ export class ChatRequest {
// Try to get more of it
delete summaryResponse.finish_reason
_this.updatingMessage = 'Summarizing more...'
let _recount = countPromptTokens(top.concat(rw).concat([summaryRequest]).concat([summaryResponse]), model)
let _recount = countPromptTokens(top.concat(rw).concat([summaryRequest]).concat([summaryResponse]), model, chatSettings)
while (rw.length && (_recount + maxSummaryTokens >= maxTokens)) {
rw.shift()
_recount = countPromptTokens(top.concat(rw).concat([summaryRequest]).concat([summaryResponse]), model)
_recount = countPromptTokens(top.concat(rw).concat([summaryRequest]).concat([summaryResponse]), model, chatSettings)
}
loopCount++
continue

@@ -3,7 +3,6 @@
import { getChatDefaults, getChatSettingList, getChatSettingObjectByKey, getExcludeFromProfile } from './Settings.svelte'
import {
saveChatStore,
apiKeyStorage,
chatsStorage,
globalStorage,
saveCustomProfile,
@@ -13,7 +12,7 @@
checkStateChange,
addChat
} from './Storage.svelte'
import type { Chat, ChatSetting, ResponseModels, SettingSelect, SelectOption, ChatSettings } from './Types.svelte'
import type { Chat, ChatSetting, SettingSelect, ChatSettings } from './Types.svelte'
import { errorNotice, sizeTextElements } from './Util.svelte'
import Fa from 'svelte-fa/src/fa.svelte'
import {
@@ -35,8 +34,7 @@
import { replace } from 'svelte-spa-router'
import { openModal } from 'svelte-modals'
import PromptConfirm from './PromptConfirm.svelte'
import { getApiBase, getEndpointModels } from './ApiUtil.svelte'
import { supportedModelKeys } from './Models.svelte'
import { getModelOptions } from './Models.svelte'

export let chatId:number
export const show = () => { showSettings() }
@@ -185,30 +183,9 @@
// Refresh settings modal
showSettingsModal++

// Load available models from OpenAI
const allModels = (await (
await fetch(getApiBase() + getEndpointModels(), {
method: 'GET',
headers: {
Authorization: `Bearer ${$apiKeyStorage}`,
'Content-Type': 'application/json'
}
})
).json()) as ResponseModels
const filteredModels = supportedModelKeys.filter((model) => allModels.data.find((m) => m.id === model))

const modelOptions:SelectOption[] = filteredModels.reduce((a, m) => {
const o:SelectOption = {
value: m,
text: m
}
a.push(o)
return a
}, [] as SelectOption[])

// Update the models in the settings
if (modelSetting) {
modelSetting.options = modelOptions
modelSetting.options = await getModelOptions()
}
// Refresh settings modal
showSettingsModal++

@@ -1,11 +1,14 @@
<script lang="ts">
import { apiKeyStorage, lastChatId, getChat, started } from './Storage.svelte'
import { apiKeyStorage, globalStorage, lastChatId, getChat, started, setGlobalSettingValueByKey } from './Storage.svelte'
import Footer from './Footer.svelte'
import { replace } from 'svelte-spa-router'
import { onMount } from 'svelte'
import { getPetalsV2Websocket } from './ApiUtil.svelte'

$: apiKey = $apiKeyStorage

let showPetalsSettings = $globalStorage.enablePetals

onMount(() => {
if (!$started) {
$started = true
@@ -19,6 +22,12 @@ onMount(() => {
$lastChatId = 0
})

const setPetalsEnabled = (event: Event) => {
const el = (event.target as HTMLInputElement)
setGlobalSettingValueByKey('enablePetals', !!el.checked)
showPetalsSettings = $globalStorage.enablePetals
}

</script>

<section class="section">
@@ -60,6 +69,8 @@ onMount(() => {
<p class="control">
<button class="button is-info" type="submit">Save</button>
</p>


</form>

{#if !apiKey}
@@ -70,6 +81,66 @@ onMount(() => {
{/if}
</div>
</article>


<article class="message" class:is-info={true}>
<div class="message-body">
<label class="label" for="enablePetals">
<input
type="checkbox"
class="checkbox"
id="enablePetals"
checked={!!$globalStorage.enablePetals}
on:click={setPetalsEnabled}
>
Use Petals API and Models
</label>
{#if showPetalsSettings}
<p>Set Petals API Endpoint:</p>
<form
class="field has-addons has-addons-right"
on:submit|preventDefault={(event) => {
if (event.target && event.target[0].value) {
setGlobalSettingValueByKey('pedalsEndpoint', (event.target[0].value).trim())
} else {
setGlobalSettingValueByKey('pedalsEndpoint', '')
}
}}
>
<p class="control is-expanded">
<input
aria-label="PetalsAPI Endpoint"
type="text"
class="input"
placeholder={getPetalsV2Websocket()}
value={$globalStorage.pedalsEndpoint || ''}
/>
</p>
<p class="control">
<button class="button is-info" type="submit">Save</button>
</p>


</form>
<p>
Only use <u>{getPetalsV2Websocket()}</u> for testing. You must set up your own Petals server for actual use.
</p>
<p>
<b>Do not send sensitive information when using Petals.</b>
</p>
<p>
For more information on Petals, see
<a href="https://github.com/petals-infra/chat.petals.dev">https://github.com/petals-infra/chat.petals.dev</a>
</p>
{/if}
{#if !apiKey}
<p class="help is-danger">
Please enter your <a href="https://platform.openai.com/account/api-keys">OpenAI API key</a> above to use ChatGPT-web.
It is required to use ChatGPT-web.
</p>
{/if}
</div>
</article>
{#if apiKey}
<article class="message is-info">
<div class="message-body">

@@ -1,43 +1,63 @@
<script context="module" lang="ts">
import type { ModelDetail, Model } from './Types.svelte'
import { getApiBase, getEndpointCompletions, getEndpointGenerations, getEndpointModels, getPetalsV2Websocket } from './ApiUtil.svelte'
import { apiKeyStorage, globalStorage } from './Storage.svelte'
import { get } from 'svelte/store'
import type { ModelDetail, Model, ResponseModels, SelectOption, ChatSettings } from './Types.svelte'
import { encode } from 'gpt-tokenizer'
import llamaTokenizer from 'llama-tokenizer-js'

// Reference: https://openai.com/pricing#language-models
// Eventually we'll add API hosts and endpoints to this
const modelDetails : Record<string, ModelDetail> = {
'gpt-4-32k': {
type: 'OpenAIChat',
prompt: 0.00006, // $0.06 per 1000 tokens prompt
completion: 0.00012, // $0.12 per 1000 tokens completion
max: 32768 // 32k max token buffer
},
'gpt-4': {
type: 'OpenAIChat',
prompt: 0.00003, // $0.03 per 1000 tokens prompt
completion: 0.00006, // $0.06 per 1000 tokens completion
max: 8192 // 8k max token buffer
},
'gpt-3.5': {
type: 'OpenAIChat',
prompt: 0.0000015, // $0.0015 per 1000 tokens prompt
completion: 0.000002, // $0.002 per 1000 tokens completion
max: 4096 // 4k max token buffer
},
'gpt-3.5-turbo-16k': {
type: 'OpenAIChat',
prompt: 0.000003, // $0.003 per 1000 tokens prompt
completion: 0.000004, // $0.004 per 1000 tokens completion
max: 16384 // 16k max token buffer
},
'meta-llama/Llama-2-70b-chat-hf': {
type: 'PetalsV2Websocket',
label: 'Petals - Llama-2-70b-chat',
stop: ['###', '</s>'],
prompt: 0.000000, // $0.000 per 1000 tokens prompt
completion: 0.000000, // $0.000 per 1000 tokens completion
max: 4096 // 4k max token buffer
}
}

const imageModels : Record<string, ModelDetail> = {
export const imageModels : Record<string, ModelDetail> = {
'dall-e-1024x1024': {
type: 'OpenAIDall-e',
prompt: 0.00,
completion: 0.020, // $0.020 per image
max: 1000 // 1000 char prompt, max
},
'dall-e-512x512': {
type: 'OpenAIDall-e',
prompt: 0.00,
completion: 0.018, // $0.018 per image
max: 1000 // 1000 char prompt, max
},
'dall-e-256x256': {
type: 'OpenAIDall-e',
prompt: 0.00,
completion: 0.016, // $0.016 per image
max: 1000 // 1000 char prompt, max
@@ -47,8 +67,9 @@ const imageModels : Record<string, ModelDetail> = {
const unknownDetail = {
prompt: 0,
completion: 0,
max: 4096
}
max: 4096,
type: 'OpenAIChat'
} as ModelDetail

// See: https://platform.openai.com/docs/models/model-endpoint-compatibility
// Eventually we'll add UI for managing this
@@ -62,7 +83,8 @@ export const supportedModels : Record<string, ModelDetail> = {
'gpt-3.5-turbo': modelDetails['gpt-3.5'],
'gpt-3.5-turbo-16k': modelDetails['gpt-3.5-turbo-16k'],
'gpt-3.5-turbo-0301': modelDetails['gpt-3.5'],
'gpt-3.5-turbo-0613': modelDetails['gpt-3.5']
'gpt-3.5-turbo-0613': modelDetails['gpt-3.5'],
'meta-llama/Llama-2-70b-chat-hf': modelDetails['meta-llama/Llama-2-70b-chat-hf']
}

const lookupList = {
@@ -75,7 +97,7 @@ export const supportedModelKeys = Object.keys({ ...supportedModels, ...imageMode

const tpCache : Record<string, ModelDetail> = {}

export const getModelDetail = (model: Model) => {
export const getModelDetail = (model: Model): ModelDetail => {
// First try to get exact match, then from cache
let r = supportedModels[model] || tpCache[model]
if (r) return r
@@ -93,4 +115,93 @@ export const getModelDetail = (model: Model) => {
return r
}

export const getEndpoint = (model: Model): string => {
const modelDetails = getModelDetail(model)
const gSettings = get(globalStorage)
switch (modelDetails.type) {
case 'PetalsV2Websocket':
return gSettings.pedalsEndpoint || getPetalsV2Websocket()
case 'OpenAIDall-e':
return getApiBase() + getEndpointGenerations()
case 'OpenAIChat':
default:
return gSettings.openAICompletionEndpoint || (getApiBase() + getEndpointCompletions())
}
}

export const getRoleTag = (role: string, model: Model, settings: ChatSettings): string => {
const modelDetails = getModelDetail(model)
switch (modelDetails.type) {
case 'PetalsV2Websocket':
if (role === 'assistant') {
return ('Assistant') +
': '
}
if (role === 'user') return 'Human: '
return ''
case 'OpenAIDall-e':
return role
case 'OpenAIChat':
default:
return role
}
}

export const getTokens = (model: Model, value: string): number[] => {
const modelDetails = getModelDetail(model)
switch (modelDetails.type) {
case 'PetalsV2Websocket':
return llamaTokenizer.encode(value)
case 'OpenAIDall-e':
return [0]
case 'OpenAIChat':
default:
return encode(value)
}
}

export const countTokens = (model: Model, value: string): number => {
return getTokens(model, value).length
}

export async function getModelOptions (): Promise<SelectOption[]> {
const gSettings = get(globalStorage)
const openAiKey = get(apiKeyStorage)
// Load available models from OpenAI
let openAiModels
try {
openAiModels = (await (
await fetch(getApiBase() + getEndpointModels(), {
method: 'GET',
headers: {
Authorization: `Bearer ${openAiKey}`,
'Content-Type': 'application/json'
}
})
).json()) as ResponseModels
} catch (e) {
openAiModels = { data: [] }
}
const filteredModels = supportedModelKeys.filter((model) => {
switch (getModelDetail(model).type) {
case 'PetalsV2Websocket':
return gSettings.enablePetals
case 'OpenAIChat':
default:
return openAiModels.data.find((m) => m.id === model)
}
})

const modelOptions:SelectOption[] = filteredModels.reduce((a, m) => {
const o:SelectOption = {
value: m,
text: m
}
a.push(o)
return a
}, [] as SelectOption[])

return modelOptions
}

</script>

@@ -1,7 +1,6 @@
<script context="module" lang="ts">
import { applyProfile } from './Profiles.svelte'
import { getChatSettings, getGlobalSettings, setGlobalSettingValueByKey } from './Storage.svelte'
import { encode } from 'gpt-tokenizer'
import { faArrowDown91, faArrowDownAZ, faCheck, faThumbTack } from '@fortawesome/free-solid-svg-icons/index'
// Setting definitions

@@ -18,6 +17,7 @@ import {
type ChatSortOption

} from './Types.svelte'
import { getTokens } from './Models.svelte'

export const defaultModel:Model = 'gpt-3.5-turbo'

@@ -104,7 +104,10 @@ export const globalDefaults: GlobalSettings = {
lastProfile: 'default',
defaultProfile: 'default',
hideSummarized: false,
chatSort: 'created'
chatSort: 'created',
openAICompletionEndpoint: '',
enablePetals: false,
pedalsEndpoint: ''
}

const excludeFromProfile = {
@@ -497,7 +500,7 @@ const chatSettingsList: ChatSetting[] = [
// console.log('logit_bias', val, getChatSettings(chatId).logit_bias)
if (!val) return null
const tokenized:Record<number, number> = Object.entries(val).reduce((a, [k, v]) => {
const tokens:number[] = encode(k)
const tokens:number[] = getTokens(getChatSettings(chatId).model, k)
tokens.forEach(t => { a[t] = v })
return a
}, {} as Record<number, number>)
@@ -536,6 +539,21 @@ const globalSettingsList:GlobalSetting[] = [
key: 'hideSummarized',
name: 'Hide Summarized Messages',
type: 'boolean'
},
{
key: 'openAICompletionEndpoint',
name: 'OpenAI Completions Endpoint',
type: 'text'
},
{
key: 'enablePetals',
name: 'Enable Petals APIs',
type: 'boolean'
},
{
key: 'pedalsEndpoint',
name: 'Petals API Endpoint',
type: 'text'
}
]

@@ -1,25 +1,49 @@
<script context="module" lang="ts">
import { getModelDetail } from './Models.svelte'
import type { Message, Model, Usage } from './Types.svelte'
import { encode } from 'gpt-tokenizer'
import { countTokens, getModelDetail, getRoleTag } from './Models.svelte'
import type { ChatSettings, Message, Model, Usage } from './Types.svelte'

export const getPrice = (tokens: Usage, model: Model): number => {
const t = getModelDetail(model)
return ((tokens.prompt_tokens * t.prompt) + (tokens.completion_tokens * t.completion))
}

export const countPromptTokens = (prompts:Message[], model:Model):number => {
return prompts.reduce((a, m) => {
a += countMessageTokens(m, model)
export const countPromptTokens = (prompts:Message[], model:Model, settings: ChatSettings):number => {
const detail = getModelDetail(model)
const count = prompts.reduce((a, m) => {
switch (detail.type) {
case 'PetalsV2Websocket':
a += countMessageTokens(m, model, settings)
break
case 'OpenAIChat':
default:
a += countMessageTokens(m, model, settings)
}
return a
}, 0) + 3 // Always seems to be message counts + 3
}, 0)
switch (detail.type) {
case 'PetalsV2Websocket':
return count + (Math.max(prompts.length - 1, 0) * countTokens(model, (detail.stop && detail.stop[0]) || '###')) // todo, make stop per model?
case 'OpenAIChat':
default:
// Not sure how OpenAI formats it, but this seems to get close to the right counts.
// Would be nice to know. This works for gpt-3.5. gpt-4 could be different.
// Complete stab in the dark here -- update if you know where all the extra tokens really come from.
return count + 3 // Always seems to be message counts + 3
}
}

export const countMessageTokens = (message:Message, model:Model):number => {
// Not sure how OpenAI formats it, but this seems to get close to the right counts.
// Would be nice to know. This works for gpt-3.5. gpt-4 could be different.
// Complete stab in the dark here -- update if you know where all the extra tokens really come from.
return encode('## ' + message.role + ' ##:\r\n\r\n' + message.content + '\r\n\r\n\r\n').length
export const countMessageTokens = (message:Message, model:Model, settings: ChatSettings):number => {
const detail = getModelDetail(model)
switch (detail.type) {
case 'PetalsV2Websocket':
return countTokens(model, getRoleTag(message.role, model, settings) + ': ' + message.content)
case 'OpenAIChat':
default:
// Not sure how OpenAI formats it, but this seems to get close to the right counts.
// Would be nice to know. This works for gpt-3.5. gpt-4 could be different.
// Complete stab in the dark here -- update if you know where all the extra tokens really come from.
return countTokens(model, '## ' + message.role + ' ##:\r\n\r\n' + message.content + '\r\n\r\n\r\n')
}
}

export const getModelMaxTokens = (model:Model):number => {

@@ -7,7 +7,12 @@ export type Model = typeof supportedModelKeys[number];

export type ImageGenerationSizes = typeof imageGenerationSizeTypes[number];

export type RequestType = 'OpenAIChat' | 'OpenAIDall-e' | 'PetalsV2Websocket'

export type ModelDetail = {
type: RequestType;
label?: string;
stop?: string[];
prompt: number;
completion: number;
max: number;
@@ -122,16 +127,16 @@ export type Chat = {
};

type ResponseOK = {
id: string;
object: string;
created: number;
choices: {
index: number;
id?: string;
object?: string;
created?: number;
choices?: {
index?: number;
message: Message;
finish_reason: string;
finish_reason?: string;
delta: Message;
}[];
usage: Usage;
usage?: Usage;
model: Model;
};

@@ -172,6 +177,9 @@ export type GlobalSettings = {
defaultProfile: string;
hideSummarized: boolean;
chatSort: ChatSortOptions;
openAICompletionEndpoint: string;
enablePetals: boolean;
pedalsEndpoint: string;
};

type SettingNumber = {
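Note on the protocol exercised by this commit: the new Petals path opens a WebSocket, starts an inference session for the model, sends a generate request with the role-tagged prompt and stop sequences, and streams partial outputs until the server sets stop. The sketch below is not part of the commit; it only restates that flow in isolation for readers. The endpoint, model name, and message fields are copied from the diff (including the commit's doSample spelling); the prompt string, token limits, and handlers are illustrative assumptions.

// Minimal standalone sketch of the Petals v2 WebSocket flow (browser TypeScript).
const endpoint = 'wss://chat.petals.dev/api/v2/generate' // default from ApiUtil.svelte above
const model = 'meta-llama/Llama-2-70b-chat-hf'
const stopSequence = '###' // first entry of the model's stop list
let text = ''

const ws = new WebSocket(endpoint)

ws.onopen = () => {
  // Step 1: open an inference session for the chosen model.
  ws.send(JSON.stringify({ type: 'open_inference_session', model, max_length: 4096 }))
  ws.onmessage = event => {
    const response = JSON.parse(event.data)
    if (!response.ok) throw new Error('Error opening socket: ' + response.traceback)
    // Step 2: request generation; chat messages are flattened into one prompt joined by the stop sequence.
    ws.send(JSON.stringify({
      type: 'generate',
      inputs: 'Human: Hello!' + stopSequence + 'Assistant: ', // illustrative prompt only
      max_new_tokens: 3,
      stop_sequence: stopSequence,
      doSample: 1,
      temperature: 0,
      top_p: 0,
      extra_stop_sequences: ['</s>']
    }))
    // Step 3: stream partial outputs until the server sets stop, then close.
    ws.onmessage = ev => {
      const chunk = JSON.parse(ev.data)
      if (!chunk.ok) throw new Error('Error in response: ' + chunk.traceback)
      text += chunk.outputs
      if (chunk.stop) {
        console.log(text)
        ws.close()
      }
    }
  }
}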