Initial test of Petals as an alternative to OpenAI

Webifi 2023-07-22 01:42:21 -05:00
parent 8e35b198da
commit 914055f1f9
11 changed files with 469 additions and 141 deletions

package-lock.json generated
View File

@ -27,6 +27,7 @@
"eslint-plugin-svelte3": "^4.0.0",
"flourite": "^1.2.4",
"gpt-tokenizer": "^2.0.0",
"llama-tokenizer-js": "^1.1.1",
"postcss": "^8.4.26",
"sass": "^1.63.6",
"stacking-order": "^2.0.0",
@ -3182,6 +3183,12 @@
"node": ">= 0.8.0"
}
},
"node_modules/llama-tokenizer-js": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/llama-tokenizer-js/-/llama-tokenizer-js-1.1.1.tgz",
"integrity": "sha512-5H2oSJnSufWGhOw6hcCGAqJeB3POmeIBzRklH3cXs0L4MSAYdwoYTodni4j5YVo6jApdhaqaNVU66gNRgXeBRg==",
"dev": true
},
"node_modules/locate-path": {
"version": "6.0.0",
"resolved": "https://registry.npmjs.org/locate-path/-/locate-path-6.0.0.tgz",

View File

@ -33,6 +33,7 @@
"eslint-plugin-svelte3": "^4.0.0",
"flourite": "^1.2.4",
"gpt-tokenizer": "^2.0.0",
"llama-tokenizer-js": "^1.1.1",
"postcss": "^8.4.26",
"sass": "^1.63.6",
"stacking-order": "^2.0.0",

View File

@ -5,10 +5,12 @@
const endpointGenerations = import.meta.env.VITE_ENDPOINT_GENERATIONS || '/v1/images/generations'
const endpointModels = import.meta.env.VITE_ENDPOINT_MODELS || '/v1/models'
const endpointEmbeddings = import.meta.env.VITE_ENDPOINT_EMBEDDINGS || '/v1/embeddings'
const endpointPetalsV2Websocket = import.meta.env.VITE_PEDALS_WEBSOCKET || 'wss://chat.petals.dev/api/v2/generate'
export const getApiBase = ():string => apiBase
export const getEndpointCompletions = ():string => endpointCompletions
export const getEndpointGenerations = ():string => endpointGenerations
export const getEndpointModels = ():string => endpointModels
export const getEndpointEmbeddings = ():string => endpointEmbeddings
export const getPetalsV2Websocket = ():string => endpointPetalsV2Websocket
</script>
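The new `endpointPetalsV2Websocket` constant follows the same pattern as the other API constants above: a Vite build-time variable with a hard-coded fallback to the public chat.petals.dev socket. A minimal sketch of overriding it at build time, assuming a standard Vite `.env.local` file (the URL is a placeholder):

```
# .env.local -- placeholder URL; VITE_PEDALS_WEBSOCKET is the variable read above
VITE_PEDALS_WEBSOCKET=wss://your-petals-host.example/api/v2/generate
```

At runtime, the `pedalsEndpoint` global setting added later in this commit takes precedence over both (see `getEndpoint` in Models.svelte below).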

View File

@ -1,9 +1,9 @@
<script context="module" lang="ts">
import { setImage } from './ImageStore.svelte'
import { countTokens } from './Models.svelte'
// TODO: Integrate API calls
import { addMessage, getLatestKnownModel, saveChatStore, setLatestKnownModel, subtractRunningTotal, updateRunningTotal } from './Storage.svelte'
import { addMessage, getLatestKnownModel, setLatestKnownModel, subtractRunningTotal, updateMessages, updateRunningTotal } from './Storage.svelte'
import type { Chat, ChatCompletionOpts, ChatImage, Message, Model, Response, ResponseImage, Usage } from './Types.svelte'
import { encode } from 'gpt-tokenizer'
import { v4 as uuidv4 } from 'uuid'
export class ChatCompletionResponse {
@ -138,10 +138,10 @@ export class ChatCompletionResponse {
message.content = this.initialFillMerge(message.content, choice.delta?.content)
message.content += choice.delta.content
}
completionTokenCount += encode(message.content).length
completionTokenCount += countTokens(this.model, message.content)
message.model = response.model
message.finish_reason = choice.finish_reason
message.streaming = choice.finish_reason === null && !this.finished
message.streaming = !choice.finish_reason && !this.finished
this.messages[i] = message
})
// total up the tokens
@ -209,10 +209,10 @@ export class ChatCompletionResponse {
}
private finish = (): void => {
this.messages.forEach(m => { m.streaming = false }) // make sure all are marked stopped
updateMessages(this.chat.id)
if (this.finished) return
this.finished = true
this.messages.forEach(m => { m.streaming = false }) // make sure all are marked stopped
saveChatStore()
const message = this.messages[0]
const model = this.model || getLatestKnownModel(this.chat.settings.model)
if (message) {
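Two behavioural changes in this file: completion tokens are now counted with the model-aware `countTokens` from Models.svelte (so Llama responses are measured with the Llama tokenizer), and the streaming flag uses `!choice.finish_reason`, which treats `undefined` and `''` the same as `null` now that `finish_reason` is optional in `ResponseOK` (see Types.svelte below). A minimal sketch of the kind of chunk this handler processes; field values are illustrative:

```ts
import { countTokens } from './Models.svelte'
import type { Model } from './Types.svelte'

// Shape based on the calls in this hunk, not a verbatim copy of Types.svelte.
const model = 'meta-llama/Llama-2-70b-chat-hf' as Model
const choice = { delta: { role: 'assistant', content: ' Hello' }, finish_reason: null as string | null | undefined }

const content = choice.delta.content || ''
const completionTokens = countTokens(model, content) // llama-tokenizer-js for Petals models, gpt-tokenizer otherwise
const streaming = !choice.finish_reason              // true for null, undefined and '' alike
console.log(completionTokens, streaming)
```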

View File

@ -7,9 +7,9 @@
import { scrollToBottom, scrollToMessage } from './Util.svelte'
import { getRequestSettingList, defaultModel } from './Settings.svelte'
import { EventStreamContentType, fetchEventSource } from '@microsoft/fetch-event-source'
import { getApiBase, getEndpointCompletions, getEndpointGenerations } from './ApiUtil.svelte'
import { v4 as uuidv4 } from 'uuid'
import { get } from 'svelte/store'
import { getEndpoint, getModelDetail, getRoleTag } from './Models.svelte'
export class ChatRequest {
constructor () {
@ -77,7 +77,7 @@ export class ChatRequest {
const chatResponse = new ChatCompletionResponse(opts)
try {
const response = await fetch(getApiBase() + getEndpointGenerations(), fetchOptions)
const response = await fetch(getEndpoint('dall-e-' + size), fetchOptions)
if (!response.ok) {
await _this.handleError(response)
} else {
@ -206,7 +206,7 @@ export class ChatRequest {
}
// Get token counts
const promptTokenCount = countPromptTokens(messagePayload, model)
const promptTokenCount = countPromptTokens(messagePayload, model, chatSettings)
const maxAllowed = maxTokens - (promptTokenCount + 1)
// Build the API request body
@ -245,6 +245,9 @@ export class ChatRequest {
// Set-up and make the request
const chatResponse = new ChatCompletionResponse(opts)
const modelDetail = getModelDetail(model)
try {
// Add our token count to the response handler
// (streaming doesn't return counts, so we need to do it client side)
@ -254,87 +257,193 @@ export class ChatRequest {
// so we deal with it ourselves
_this.controller = new AbortController()
const signal = _this.controller.signal
const abortListener = (e:Event) => {
_this.updating = false
_this.updatingMessage = ''
chatResponse.updateFromError('User aborted request.')
signal.removeEventListener('abort', abortListener)
}
signal.addEventListener('abort', abortListener)
const fetchOptions = {
method: 'POST',
headers: {
Authorization: `Bearer ${getApiKey()}`,
'Content-Type': 'application/json'
},
body: JSON.stringify(request),
signal
}
if (opts.streaming) {
/**
* Streaming request/response
* We'll get the response a token at a time, as soon as they are ready
*/
if (modelDetail.type === 'PetalsV2Websocket') {
// Petals
const ws = new WebSocket(getEndpoint(model))
const abortListener = (e:Event) => {
_this.updating = false
_this.updatingMessage = ''
chatResponse.updateFromError('User aborted request.')
signal.removeEventListener('abort', abortListener)
ws.close()
}
signal.addEventListener('abort', abortListener)
const stopSequences = modelDetail.stop || ['###']
const stopSequencesC = stopSequences.slice()
const stopSequence = stopSequencesC.shift()
chatResponse.onFinish(() => {
_this.updating = false
_this.updatingMessage = ''
})
fetchEventSource(getApiBase() + getEndpointCompletions(), {
...fetchOptions,
openWhenHidden: true,
onmessage (ev) {
// Remove updating indicator
_this.updating = 1 // hide indicator, but still signal we're updating
_this.updatingMessage = ''
// console.log('ev.data', ev.data)
if (!chatResponse.hasFinished()) {
if (ev.data === '[DONE]') {
// ?? anything to do when "[DONE]"?
} else {
const data = JSON.parse(ev.data)
// console.log('data', data)
window.setTimeout(() => { chatResponse.updateFromAsyncResponse(data) }, 1)
}
ws.onopen = () => {
ws.send(JSON.stringify({
type: 'open_inference_session',
model,
max_length: maxTokens || opts.maxTokens
}))
ws.onmessage = event => {
const response = JSON.parse(event.data)
if (!response.ok) {
const err = new Error('Error opening socket: ' + response.traceback)
console.error(err)
throw err
}
},
onclose () {
const rMessages = request.messages || [] as Message[]
const inputArray = (rMessages).reduce((a, m) => {
const c = getRoleTag(m.role, model, chatSettings) + m.content
a.push(c)
return a
}, [] as string[])
const lastMessage = rMessages[rMessages.length - 1]
if (lastMessage && lastMessage.role !== 'assistant') {
inputArray.push(getRoleTag('assistant', model, chatSettings))
}
const petalsRequest = {
type: 'generate',
inputs: (request.messages || [] as Message[]).reduce((a, m) => {
const c = getRoleTag(m.role, model, chatSettings) + m.content
a.push(c)
return a
}, [] as string[]).join(stopSequence),
max_new_tokens: 3, // wait for up to 3 tokens before displaying
stop_sequence: stopSequence,
doSample: 1,
temperature: request.temperature || 0,
top_p: request.top_p || 0,
extra_stop_sequences: stopSequencesC
}
ws.send(JSON.stringify(petalsRequest))
ws.onmessage = event => {
// Remove updating indicator
_this.updating = 1 // hide indicator, but still signal we're updating
_this.updatingMessage = ''
const response = JSON.parse(event.data)
if (!response.ok) {
const err = new Error('Error in response: ' + response.traceback)
console.error(err)
throw err
}
window.setTimeout(() => {
chatResponse.updateFromAsyncResponse(
{
model,
choices: [{
delta: {
content: response.outputs,
role: 'assistant'
},
finish_reason: (response.stop ? 'stop' : null)
}]
} as any
)
if (response.stop) {
const message = chatResponse.getMessages()[0]
if (message) {
for (let i = 0, l = stopSequences.length; i < l; i++) {
if (message.content.endsWith(stopSequences[i])) {
message.content = message.content.slice(0, message.content.length - stopSequences[i].length)
updateMessages(chatId)
}
}
}
}
}, 1)
}
}
ws.onclose = () => {
_this.updating = false
_this.updatingMessage = ''
chatResponse.updateFromClose()
},
onerror (err) {
}
ws.onerror = err => {
console.error(err)
throw err
},
async onopen (response) {
if (response.ok && response.headers.get('content-type') === EventStreamContentType) {
// everything's good
} else {
// client-side errors are usually non-retriable:
await _this.handleError(response)
}
}
}).catch(err => {
}
} else {
// OpenAI
const abortListener = (e:Event) => {
_this.updating = false
_this.updatingMessage = ''
chatResponse.updateFromError(err.message)
})
} else {
chatResponse.updateFromError('User aborted request.')
signal.removeEventListener('abort', abortListener)
}
signal.addEventListener('abort', abortListener)
const fetchOptions = {
method: 'POST',
headers: {
Authorization: `Bearer ${getApiKey()}`,
'Content-Type': 'application/json'
},
body: JSON.stringify(request),
signal
}
if (opts.streaming) {
/**
* Streaming request/response
* We'll get the response a token at a time, as soon as they are ready
*/
chatResponse.onFinish(() => {
_this.updating = false
_this.updatingMessage = ''
})
fetchEventSource(getEndpoint(model), {
...fetchOptions,
openWhenHidden: true,
onmessage (ev) {
// Remove updating indicator
_this.updating = 1 // hide indicator, but still signal we're updating
_this.updatingMessage = ''
// console.log('ev.data', ev.data)
if (!chatResponse.hasFinished()) {
if (ev.data === '[DONE]') {
// ?? anything to do when "[DONE]"?
} else {
const data = JSON.parse(ev.data)
// console.log('data', data)
window.setTimeout(() => { chatResponse.updateFromAsyncResponse(data) }, 1)
}
}
},
onclose () {
_this.updating = false
_this.updatingMessage = ''
chatResponse.updateFromClose()
},
onerror (err) {
console.error(err)
throw err
},
async onopen (response) {
if (response.ok && response.headers.get('content-type') === EventStreamContentType) {
// everything's good
} else {
// client-side errors are usually non-retriable:
await _this.handleError(response)
}
}
}).catch(err => {
_this.updating = false
_this.updatingMessage = ''
chatResponse.updateFromError(err.message)
})
} else {
/**
* Non-streaming request/response
* We'll get the response all at once, after a long delay
*/
const response = await fetch(getApiBase() + getEndpointCompletions(), fetchOptions)
if (!response.ok) {
await _this.handleError(response)
} else {
const json = await response.json()
// Remove updating indicator
_this.updating = false
_this.updatingMessage = ''
chatResponse.updateFromSyncResponse(json)
const response = await fetch(getEndpoint(model), fetchOptions)
if (!response.ok) {
await _this.handleError(response)
} else {
const json = await response.json()
// Remove updating indicator
_this.updating = false
_this.updatingMessage = ''
chatResponse.updateFromSyncResponse(json)
}
}
}
} catch (e) {
@ -393,11 +502,11 @@ export class ChatRequest {
* Gets an estimate of how many extra tokens will be added that won't be part of the visible messages
* @param filtered
*/
private getTokenCountPadding (filtered: Message[]): number {
private getTokenCountPadding (filtered: Message[], settings: ChatSettings): number {
let result = 0
// add cost of hiddenPromptPrefix
result += this.buildHiddenPromptPrefixMessages(filtered)
.reduce((a, m) => a + countMessageTokens(m, this.getModel()), 0)
.reduce((a, m) => a + countMessageTokens(m, this.getModel(), settings), 0)
// more here eventually?
return result
}
@ -419,10 +528,10 @@ export class ChatRequest {
}
// Get extra counts for when the prompts are finally sent.
const countPadding = this.getTokenCountPadding(filtered)
const countPadding = this.getTokenCountPadding(filtered, chatSettings)
// See if we have enough to apply any of the reduction modes
const fullPromptSize = countPromptTokens(filtered, model) + countPadding
const fullPromptSize = countPromptTokens(filtered, model, chatSettings) + countPadding
if (fullPromptSize < chatSettings.summaryThreshold) return await continueRequest() // nothing to do yet
const overMax = fullPromptSize > maxTokens * 0.95
@ -445,12 +554,12 @@ export class ChatRequest {
* *************************************************************
*/
let promptSize = countPromptTokens(top.concat(rw), model) + countPadding
let promptSize = countPromptTokens(top.concat(rw), model, chatSettings) + countPadding
while (rw.length && rw.length > pinBottom && promptSize >= chatSettings.summaryThreshold) {
const rolled = rw.shift()
// Hide messages we're "rolling"
if (rolled) rolled.suppress = true
promptSize = countPromptTokens(top.concat(rw), model) + countPadding
promptSize = countPromptTokens(top.concat(rw), model, chatSettings) + countPadding
}
// Run a new request, now with the rolled messages hidden
return await _this.sendRequest(get(currentChatMessages), {
@ -466,26 +575,26 @@ export class ChatRequest {
const bottom = rw.slice(0 - pinBottom)
let continueCounter = chatSettings.summaryExtend + 1
rw = rw.slice(0, 0 - pinBottom)
let reductionPoolSize = countPromptTokens(rw, model)
let reductionPoolSize = countPromptTokens(rw, model, chatSettings)
const ss = Math.abs(chatSettings.summarySize)
const getSS = ():number => (ss < 1 && ss > 0)
? Math.round(reductionPoolSize * ss) // If summarySize between 0 and 1, use percentage of reduced
: Math.min(ss, reductionPoolSize * 0.5) // If > 1, use token count
const topSize = countPromptTokens(top, model)
const topSize = countPromptTokens(top, model, chatSettings)
let maxSummaryTokens = getSS()
let promptSummary = prepareSummaryPrompt(chatId, maxSummaryTokens)
const summaryRequest = { role: 'user', content: promptSummary } as Message
let promptSummarySize = countMessageTokens(summaryRequest, model)
let promptSummarySize = countMessageTokens(summaryRequest, model, chatSettings)
// Make sure there is enough room to generate the summary, and try to make sure
// the last prompt is a user prompt as that seems to work better for summaries
while ((topSize + reductionPoolSize + promptSummarySize + maxSummaryTokens) >= maxTokens ||
(reductionPoolSize >= 100 && rw[rw.length - 1]?.role !== 'user')) {
bottom.unshift(rw.pop() as Message)
reductionPoolSize = countPromptTokens(rw, model)
reductionPoolSize = countPromptTokens(rw, model, chatSettings)
maxSummaryTokens = getSS()
promptSummary = prepareSummaryPrompt(chatId, maxSummaryTokens)
summaryRequest.content = promptSummary
promptSummarySize = countMessageTokens(summaryRequest, model)
promptSummarySize = countMessageTokens(summaryRequest, model, chatSettings)
}
if (reductionPoolSize < 50) {
if (overMax) addError(chatId, 'Check summary settings. Unable to summarize enough messages.')
@ -571,10 +680,10 @@ export class ChatRequest {
// Try to get more of it
delete summaryResponse.finish_reason
_this.updatingMessage = 'Summarizing more...'
let _recount = countPromptTokens(top.concat(rw).concat([summaryRequest]).concat([summaryResponse]), model)
let _recount = countPromptTokens(top.concat(rw).concat([summaryRequest]).concat([summaryResponse]), model, chatSettings)
while (rw.length && (_recount + maxSummaryTokens >= maxTokens)) {
rw.shift()
_recount = countPromptTokens(top.concat(rw).concat([summaryRequest]).concat([summaryResponse]), model)
_recount = countPromptTokens(top.concat(rw).concat([summaryRequest]).concat([summaryResponse]), model, chatSettings)
}
loopCount++
continue
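For reference, the Petals branch above speaks the chat.petals.dev v2 websocket protocol in two steps: an `open_inference_session` message, then a `generate` message whose `inputs` field is the whole conversation flattened with role tags and joined by the stop sequence; each reply carries partial `outputs` plus a `stop` flag, and the handler wraps those into OpenAI-style deltas so `ChatCompletionResponse` stays backend-agnostic. A stripped-down sketch of the same exchange outside the class; message shapes mirror the code above, while the prompt text and sampling values are illustrative:

```ts
// Minimal sketch of the Petals v2 websocket exchange used above (browser context).
const endpoint = 'wss://chat.petals.dev/api/v2/generate' // test endpoint only, per the warning on the Home screen
const model = 'meta-llama/Llama-2-70b-chat-hf'
const stopSequence = '###'

const ws = new WebSocket(endpoint)
ws.onopen = () => {
  ws.send(JSON.stringify({ type: 'open_inference_session', model, max_length: 4096 }))
  ws.onmessage = event => {
    const response = JSON.parse(event.data)
    if (!response.ok) throw new Error('Error opening socket: ' + response.traceback)
    // Conversation flattened the same way as above: role tag + content, joined by the stop sequence.
    const inputs = ['Human: Hello', 'Assistant: '].join(stopSequence)
    ws.send(JSON.stringify({
      type: 'generate',
      inputs,
      max_new_tokens: 3,          // wait for up to 3 tokens before displaying, as above
      stop_sequence: stopSequence,
      doSample: 1,
      temperature: 0.7,           // illustrative sampling values
      top_p: 0.9,
      extra_stop_sequences: ['</s>']
    }))
    let output = ''
    ws.onmessage = ev => {
      const r = JSON.parse(ev.data)
      if (!r.ok) throw new Error('Error in response: ' + r.traceback)
      output += r.outputs   // partial text, a few tokens at a time
      if (r.stop) {         // server hit a stop sequence or the length limit
        console.log(output)
        ws.close()
      }
    }
  }
}
```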

View File

@ -3,7 +3,6 @@
import { getChatDefaults, getChatSettingList, getChatSettingObjectByKey, getExcludeFromProfile } from './Settings.svelte'
import {
saveChatStore,
apiKeyStorage,
chatsStorage,
globalStorage,
saveCustomProfile,
@ -13,7 +12,7 @@
checkStateChange,
addChat
} from './Storage.svelte'
import type { Chat, ChatSetting, ResponseModels, SettingSelect, SelectOption, ChatSettings } from './Types.svelte'
import type { Chat, ChatSetting, SettingSelect, ChatSettings } from './Types.svelte'
import { errorNotice, sizeTextElements } from './Util.svelte'
import Fa from 'svelte-fa/src/fa.svelte'
import {
@ -35,8 +34,7 @@
import { replace } from 'svelte-spa-router'
import { openModal } from 'svelte-modals'
import PromptConfirm from './PromptConfirm.svelte'
import { getApiBase, getEndpointModels } from './ApiUtil.svelte'
import { supportedModelKeys } from './Models.svelte'
import { getModelOptions } from './Models.svelte'
export let chatId:number
export const show = () => { showSettings() }
@ -185,30 +183,9 @@
// Refresh settings modal
showSettingsModal++
// Load available models from OpenAI
const allModels = (await (
await fetch(getApiBase() + getEndpointModels(), {
method: 'GET',
headers: {
Authorization: `Bearer ${$apiKeyStorage}`,
'Content-Type': 'application/json'
}
})
).json()) as ResponseModels
const filteredModels = supportedModelKeys.filter((model) => allModels.data.find((m) => m.id === model))
const modelOptions:SelectOption[] = filteredModels.reduce((a, m) => {
const o:SelectOption = {
value: m,
text: m
}
a.push(o)
return a
}, [] as SelectOption[])
// Update the models in the settings
if (modelSetting) {
modelSetting.options = modelOptions
modelSetting.options = await getModelOptions()
}
// Refresh settings modal
showSettingsModal++

View File

@ -1,11 +1,14 @@
<script lang="ts">
import { apiKeyStorage, lastChatId, getChat, started } from './Storage.svelte'
import { apiKeyStorage, globalStorage, lastChatId, getChat, started, setGlobalSettingValueByKey } from './Storage.svelte'
import Footer from './Footer.svelte'
import { replace } from 'svelte-spa-router'
import { onMount } from 'svelte'
import { getPetalsV2Websocket } from './ApiUtil.svelte'
$: apiKey = $apiKeyStorage
let showPetalsSettings = $globalStorage.enablePetals
onMount(() => {
if (!$started) {
$started = true
@ -19,6 +22,12 @@ onMount(() => {
$lastChatId = 0
})
const setPetalsEnabled = (event: Event) => {
const el = (event.target as HTMLInputElement)
setGlobalSettingValueByKey('enablePetals', !!el.checked)
showPetalsSettings = $globalStorage.enablePetals
}
</script>
<section class="section">
@ -60,6 +69,8 @@ onMount(() => {
<p class="control">
<button class="button is-info" type="submit">Save</button>
</p>
</form>
{#if !apiKey}
@ -70,6 +81,66 @@ onMount(() => {
{/if}
</div>
</article>
<article class="message" class:is-info={true}>
<div class="message-body">
<label class="label" for="enablePetals">
<input
type="checkbox"
class="checkbox"
id="enablePetals"
checked={!!$globalStorage.enablePetals}
on:click={setPetalsEnabled}
>
Use Petals API and Models
</label>
{#if showPetalsSettings}
<p>Set Petals API Endpoint:</p>
<form
class="field has-addons has-addons-right"
on:submit|preventDefault={(event) => {
if (event.target && event.target[0].value) {
setGlobalSettingValueByKey('pedalsEndpoint', (event.target[0].value).trim())
} else {
setGlobalSettingValueByKey('pedalsEndpoint', '')
}
}}
>
<p class="control is-expanded">
<input
aria-label="Petals API Endpoint"
type="text"
class="input"
placeholder={getPetalsV2Websocket()}
value={$globalStorage.pedalsEndpoint || ''}
/>
</p>
<p class="control">
<button class="button is-info" type="submit">Save</button>
</p>
</form>
<p>
Only use <u>{getPetalsV2Websocket()}</u> for testing. You must set up your own Petals server for actual use.
</p>
<p>
<b>Do not send sensitive information when using Petals.</b>
</p>
<p>
For more information on Petals, see
<a href="https://github.com/petals-infra/chat.petals.dev">https://github.com/petals-infra/chat.petals.dev</a>
</p>
{/if}
{#if !apiKey}
<p class="help is-danger">
Please enter your <a href="https://platform.openai.com/account/api-keys">OpenAI API key</a> above to use ChatGPT-web.
It is required to use ChatGPT-web.
</p>
{/if}
</div>
</article>
{#if apiKey}
<article class="message is-info">
<div class="message-body">
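The checkbox and endpoint form above write two of the new global settings; everything else keys off those values. A small sketch of what gets stored and where it is read; the endpoint URL is a placeholder:

```ts
import { get } from 'svelte/store'
import { globalStorage, setGlobalSettingValueByKey } from './Storage.svelte'
import { getEndpoint } from './Models.svelte'

// What the Home screen controls end up storing:
setGlobalSettingValueByKey('enablePetals', true)
setGlobalSettingValueByKey('pedalsEndpoint', 'wss://my-petals-server.example/api/v2/generate') // placeholder

// enablePetals gates the Petals entries in getModelOptions();
// pedalsEndpoint, when set, is what getEndpoint() returns for PetalsV2Websocket models.
console.log(get(globalStorage).enablePetals)
console.log(getEndpoint('meta-llama/Llama-2-70b-chat-hf'))
```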

View File

@ -1,43 +1,63 @@
<script context="module" lang="ts">
import type { ModelDetail, Model } from './Types.svelte'
import { getApiBase, getEndpointCompletions, getEndpointGenerations, getEndpointModels, getPetalsV2Websocket } from './ApiUtil.svelte'
import { apiKeyStorage, globalStorage } from './Storage.svelte'
import { get } from 'svelte/store'
import type { ModelDetail, Model, ResponseModels, SelectOption, ChatSettings } from './Types.svelte'
import { encode } from 'gpt-tokenizer'
import llamaTokenizer from 'llama-tokenizer-js'
// Reference: https://openai.com/pricing#language-models
// Eventually we'll add API hosts and endpoints to this
const modelDetails : Record<string, ModelDetail> = {
'gpt-4-32k': {
type: 'OpenAIChat',
prompt: 0.00006, // $0.06 per 1000 tokens prompt
completion: 0.00012, // $0.12 per 1000 tokens completion
max: 32768 // 32k max token buffer
},
'gpt-4': {
type: 'OpenAIChat',
prompt: 0.00003, // $0.03 per 1000 tokens prompt
completion: 0.00006, // $0.06 per 1000 tokens completion
max: 8192 // 8k max token buffer
},
'gpt-3.5': {
type: 'OpenAIChat',
prompt: 0.0000015, // $0.0015 per 1000 tokens prompt
completion: 0.000002, // $0.002 per 1000 tokens completion
max: 4096 // 4k max token buffer
},
'gpt-3.5-turbo-16k': {
type: 'OpenAIChat',
prompt: 0.000003, // $0.003 per 1000 tokens prompt
completion: 0.000004, // $0.004 per 1000 tokens completion
max: 16384 // 16k max token buffer
},
'meta-llama/Llama-2-70b-chat-hf': {
type: 'PetalsV2Websocket',
label: 'Petals - Llama-2-70b-chat',
stop: ['###', '</s>'],
prompt: 0.000000, // $0.000 per 1000 tokens prompt
completion: 0.000000, // $0.000 per 1000 tokens completion
max: 4096 // 4k max token buffer
}
}
const imageModels : Record<string, ModelDetail> = {
export const imageModels : Record<string, ModelDetail> = {
'dall-e-1024x1024': {
type: 'OpenAIDall-e',
prompt: 0.00,
completion: 0.020, // $0.020 per image
max: 1000 // 1000 char prompt, max
},
'dall-e-512x512': {
type: 'OpenAIDall-e',
prompt: 0.00,
completion: 0.018, // $0.018 per image
max: 1000 // 1000 char prompt, max
},
'dall-e-256x256': {
type: 'OpenAIDall-e',
prompt: 0.00,
completion: 0.016, // $0.016 per image
max: 1000 // 1000 char prompt, max
@ -47,8 +67,9 @@ const imageModels : Record<string, ModelDetail> = {
const unknownDetail = {
prompt: 0,
completion: 0,
max: 4096
}
max: 4096,
type: 'OpenAIChat'
} as ModelDetail
// See: https://platform.openai.com/docs/models/model-endpoint-compatibility
// Eventually we'll add UI for managing this
@ -62,7 +83,8 @@ export const supportedModels : Record<string, ModelDetail> = {
'gpt-3.5-turbo': modelDetails['gpt-3.5'],
'gpt-3.5-turbo-16k': modelDetails['gpt-3.5-turbo-16k'],
'gpt-3.5-turbo-0301': modelDetails['gpt-3.5'],
'gpt-3.5-turbo-0613': modelDetails['gpt-3.5']
'gpt-3.5-turbo-0613': modelDetails['gpt-3.5'],
'meta-llama/Llama-2-70b-chat-hf': modelDetails['meta-llama/Llama-2-70b-chat-hf']
}
const lookupList = {
@ -75,7 +97,7 @@ export const supportedModelKeys = Object.keys({ ...supportedModels, ...imageMode
const tpCache : Record<string, ModelDetail> = {}
export const getModelDetail = (model: Model) => {
export const getModelDetail = (model: Model): ModelDetail => {
// First try to get exact match, then from cache
let r = supportedModels[model] || tpCache[model]
if (r) return r
@ -93,4 +115,93 @@ export const getModelDetail = (model: Model) => {
return r
}
export const getEndpoint = (model: Model): string => {
const modelDetails = getModelDetail(model)
const gSettings = get(globalStorage)
switch (modelDetails.type) {
case 'PetalsV2Websocket':
return gSettings.pedalsEndpoint || getPetalsV2Websocket()
case 'OpenAIDall-e':
return getApiBase() + getEndpointGenerations()
case 'OpenAIChat':
default:
return gSettings.openAICompletionEndpoint || (getApiBase() + getEndpointCompletions())
}
}
export const getRoleTag = (role: string, model: Model, settings: ChatSettings): string => {
const modelDetails = getModelDetail(model)
switch (modelDetails.type) {
case 'PetalsV2Websocket':
if (role === 'assistant') {
return ('Assistant') +
': '
}
if (role === 'user') return 'Human: '
return ''
case 'OpenAIDall-e':
return role
case 'OpenAIChat':
default:
return role
}
}
export const getTokens = (model: Model, value: string): number[] => {
const modelDetails = getModelDetail(model)
switch (modelDetails.type) {
case 'PetalsV2Websocket':
return llamaTokenizer.encode(value)
case 'OpenAIDall-e':
return [0]
case 'OpenAIChat':
default:
return encode(value)
}
}
export const countTokens = (model: Model, value: string): number => {
return getTokens(model, value).length
}
export async function getModelOptions (): Promise<SelectOption[]> {
const gSettings = get(globalStorage)
const openAiKey = get(apiKeyStorage)
// Load available models from OpenAI
let openAiModels
try {
openAiModels = (await (
await fetch(getApiBase() + getEndpointModels(), {
method: 'GET',
headers: {
Authorization: `Bearer ${openAiKey}`,
'Content-Type': 'application/json'
}
})
).json()) as ResponseModels
} catch (e) {
openAiModels = { data: [] }
}
const filteredModels = supportedModelKeys.filter((model) => {
switch (getModelDetail(model).type) {
case 'PetalsV2Websocket':
return gSettings.enablePetals
case 'OpenAIChat':
default:
return openAiModels.data.find((m) => m.id === model)
}
})
const modelOptions:SelectOption[] = filteredModels.reduce((a, m) => {
const o:SelectOption = {
value: m,
text: m
}
a.push(o)
return a
}, [] as SelectOption[])
return modelOptions
}
</script>
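Taken together, these helpers let the rest of the app stay model-agnostic: callers ask Models.svelte for the endpoint, role formatting and tokenizer instead of hard-coding OpenAI behaviour. A short usage sketch; the settings argument is threaded through but not yet consulted by these paths, so an empty cast stands in for a real `ChatSettings`, and the commented results are indicative rather than exact:

```ts
import { getEndpoint, getModelDetail, getRoleTag, countTokens, getModelOptions } from './Models.svelte'
import type { ChatSettings } from './Types.svelte'

const petalsModel = 'meta-llama/Llama-2-70b-chat-hf'
const openAiModel = 'gpt-3.5-turbo'
const settings = {} as ChatSettings // not read by getRoleTag in this commit

console.log(getModelDetail(petalsModel).type)          // 'PetalsV2Websocket'
console.log(getEndpoint(petalsModel))                  // pedalsEndpoint override or wss://chat.petals.dev/api/v2/generate
console.log(getEndpoint(openAiModel))                  // openAICompletionEndpoint override or the default completions URL
console.log(getRoleTag('user', petalsModel, settings)) // 'Human: '
console.log(getRoleTag('user', openAiModel, settings)) // 'user'
console.log(countTokens(petalsModel, 'Hello world'))   // llama-tokenizer-js count
console.log(countTokens(openAiModel, 'Hello world'))   // gpt-tokenizer count

// Dropdown options: OpenAI models the key can access, plus Petals models when enablePetals is on.
getModelOptions().then(options => console.log(options.map(o => o.value)))
```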

View File

@ -1,7 +1,6 @@
<script context="module" lang="ts">
import { applyProfile } from './Profiles.svelte'
import { getChatSettings, getGlobalSettings, setGlobalSettingValueByKey } from './Storage.svelte'
import { encode } from 'gpt-tokenizer'
import { faArrowDown91, faArrowDownAZ, faCheck, faThumbTack } from '@fortawesome/free-solid-svg-icons/index'
// Setting definitions
@ -18,6 +17,7 @@ import {
type ChatSortOption
} from './Types.svelte'
import { getTokens } from './Models.svelte'
export const defaultModel:Model = 'gpt-3.5-turbo'
@ -104,7 +104,10 @@ export const globalDefaults: GlobalSettings = {
lastProfile: 'default',
defaultProfile: 'default',
hideSummarized: false,
chatSort: 'created'
chatSort: 'created',
openAICompletionEndpoint: '',
enablePetals: false,
pedalsEndpoint: ''
}
const excludeFromProfile = {
@ -497,7 +500,7 @@ const chatSettingsList: ChatSetting[] = [
// console.log('logit_bias', val, getChatSettings(chatId).logit_bias)
if (!val) return null
const tokenized:Record<number, number> = Object.entries(val).reduce((a, [k, v]) => {
const tokens:number[] = encode(k)
const tokens:number[] = getTokens(getChatSettings(chatId).model, k)
tokens.forEach(t => { a[t] = v })
return a
}, {} as Record<number, number>)
@ -536,6 +539,21 @@ const globalSettingsList:GlobalSetting[] = [
key: 'hideSummarized',
name: 'Hide Summarized Messages',
type: 'boolean'
},
{
key: 'openAICompletionEndpoint',
name: 'OpenAI Completions Endpoint',
type: 'text'
},
{
key: 'enablePetals',
name: 'Enable Petals APIs',
type: 'boolean'
},
{
key: 'pedalsEndpoint',
name: 'Petals API Endpoint',
type: 'text'
}
]
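The `logit_bias` change above is the user-visible payoff of the tokenizer switch: the word-to-bias map is converted to token-id-to-bias using whichever tokenizer matches the chat's model. A rough sketch of that conversion; the token ids in the comment are illustrative and depend on the tokenizer:

```ts
import { getTokens } from './Models.svelte'

// User-entered setting: { "text": bias }
const logitBias: Record<string, number> = { ' the': -100, 'Hello': 5 }
const model = 'gpt-3.5-turbo' // a Petals model would be run through llama-tokenizer-js instead

// Same reduce as in the setting's conversion hook above.
const tokenized = Object.entries(logitBias).reduce((a, [k, v]) => {
  getTokens(model, k).forEach(t => { a[t] = v })
  return a
}, {} as Record<number, number>)

console.log(tokenized) // e.g. { 279: -100, 9906: 5 } -- illustrative ids
```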

View File

@ -1,25 +1,49 @@
<script context="module" lang="ts">
import { getModelDetail } from './Models.svelte'
import type { Message, Model, Usage } from './Types.svelte'
import { encode } from 'gpt-tokenizer'
import { countTokens, getModelDetail, getRoleTag } from './Models.svelte'
import type { ChatSettings, Message, Model, Usage } from './Types.svelte'
export const getPrice = (tokens: Usage, model: Model): number => {
const t = getModelDetail(model)
return ((tokens.prompt_tokens * t.prompt) + (tokens.completion_tokens * t.completion))
}
export const countPromptTokens = (prompts:Message[], model:Model):number => {
return prompts.reduce((a, m) => {
a += countMessageTokens(m, model)
export const countPromptTokens = (prompts:Message[], model:Model, settings: ChatSettings):number => {
const detail = getModelDetail(model)
const count = prompts.reduce((a, m) => {
switch (detail.type) {
case 'PetalsV2Websocket':
a += countMessageTokens(m, model, settings)
break
case 'OpenAIChat':
default:
a += countMessageTokens(m, model, settings)
}
return a
}, 0) + 3 // Always seems to be message counts + 3
}, 0)
switch (detail.type) {
case 'PetalsV2Websocket':
return count + (Math.max(prompts.length - 1, 0) * countTokens(model, (detail.stop && detail.stop[0]) || '###')) // todo, make stop per model?
case 'OpenAIChat':
default:
// Not sure how OpenAI formats it, but this seems to get close to the right counts.
// Would be nice to know. This works for gpt-3.5. gpt-4 could be different.
// Complete stab in the dark here -- update if you know where all the extra tokens really come from.
return count + 3 // Always seems to be message counts + 3
}
}
export const countMessageTokens = (message:Message, model:Model):number => {
// Not sure how OpenAI formats it, but this seems to get close to the right counts.
// Would be nice to know. This works for gpt-3.5. gpt-4 could be different.
// Complete stab in the dark here -- update if you know where all the extra tokens really come from.
return encode('## ' + message.role + ' ##:\r\n\r\n' + message.content + '\r\n\r\n\r\n').length
export const countMessageTokens = (message:Message, model:Model, settings: ChatSettings):number => {
const detail = getModelDetail(model)
switch (detail.type) {
case 'PetalsV2Websocket':
return countTokens(model, getRoleTag(message.role, model, settings) + ': ' + message.content)
case 'OpenAIChat':
default:
// Not sure how OpenAI formats it, but this seems to get close to the right counts.
// Would be nice to know. This works for gpt-3.5. gpt-4 could be different.
// Complete stab in the dark here -- update if you know where all the extra tokens really come from.
return countTokens(model, '## ' + message.role + ' ##:\r\n\r\n' + message.content + '\r\n\r\n\r\n')
}
}
export const getModelMaxTokens = (model:Model):number => {
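The prompt-count arithmetic now differs per backend: OpenAI-style counting wraps each message in its '## role ##' scaffolding and adds a constant 3, while Petals-style counting measures role tag plus content per message and then adds the stop-sequence tokens that will sit between messages once the prompt is flattened. A small worked sketch; the module path is assumed (the diff does not show the filename) and the printed counts depend on the tokenizers:

```ts
import { countPromptTokens, countMessageTokens } from './Stats.svelte' // assumed path
import type { ChatSettings, Message } from './Types.svelte'

const settings = {} as ChatSettings // threaded through, but not yet consulted by these code paths
const messages = [
  { role: 'user', content: 'Hello' },
  { role: 'assistant', content: 'Hi there!' }
] as Message[]

// OpenAI-style: per-message counts + 3
console.log(countPromptTokens(messages, 'gpt-3.5-turbo', settings))

// Petals-style: per-message counts + one '###' stop sequence for the single gap between the two messages
console.log(countPromptTokens(messages, 'meta-llama/Llama-2-70b-chat-hf', settings))

// Per-message count for the Petals path: role tag + content, measured with the Llama tokenizer
console.log(countMessageTokens(messages[0], 'meta-llama/Llama-2-70b-chat-hf', settings))
```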

View File

@ -7,7 +7,12 @@ export type Model = typeof supportedModelKeys[number];
export type ImageGenerationSizes = typeof imageGenerationSizeTypes[number];
export type RequestType = 'OpenAIChat' | 'OpenAIDall-e' | 'PetalsV2Websocket'
export type ModelDetail = {
type: RequestType;
label?: string;
stop?: string[];
prompt: number;
completion: number;
max: number;
@ -122,16 +127,16 @@ export type Chat = {
};
type ResponseOK = {
id: string;
object: string;
created: number;
choices: {
index: number;
id?: string;
object?: string;
created?: number;
choices?: {
index?: number;
message: Message;
finish_reason: string;
finish_reason?: string;
delta: Message;
}[];
usage: Usage;
usage?: Usage;
model: Model;
};
@ -172,6 +177,9 @@ export type GlobalSettings = {
defaultProfile: string;
hideSummarized: boolean;
chatSort: ChatSortOptions;
openAICompletionEndpoint: string;
enablePetals: boolean;
pedalsEndpoint: string;
};
type SettingNumber = {
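With `type`, `label` and `stop` added to `ModelDetail`, wiring up another Petals-hosted model is just another record entry in Models.svelte plus a key in `supportedModels`. A hypothetical entry showing the extended shape; the label, pricing and context size are placeholders, not part of this commit:

```ts
import type { ModelDetail } from './Types.svelte'

// Hypothetical entry -- illustrates the extended ModelDetail shape only.
const exampleDetail: ModelDetail = {
  type: 'PetalsV2Websocket',
  label: 'Petals - Some-Other-Chat-Model', // display label (new optional field)
  stop: ['###', '</s>'],                   // sequences that end a generation
  prompt: 0.0,                             // Petals models are priced at $0 in this commit
  completion: 0.0,
  max: 2048                                // placeholder context window
}
```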