Move token counting to model detail.

Webifi committed 2023-08-15 21:46:33 -05:00
parent 91885384a1
commit a08d8bcd54
4 changed files with 38 additions and 35 deletions
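
In short: Stats.svelte previously switched on the provider type ('Petals' vs. 'OpenAIChat') inside countPromptTokens and countMessageTokens; this commit moves each branch into the provider's own ModelDetail object and leaves Stats.svelte as a thin dispatcher. A minimal sketch of the pattern, with simplified stand-in types rather than the project's real ones:

  type SketchMessage = { role: string, content: string }

  // Each provider supplies its own counting callbacks...
  interface SketchDetail {
    countMessageTokens: (message: SketchMessage) => number
    countPromptTokens: (prompts: SketchMessage[]) => number
  }

  const countMessage = (m: SketchMessage) => m.content.length // stand-in for a real tokenizer

  const sketchOpenAI: SketchDetail = {
    countMessageTokens: countMessage,
    // per-message sum plus the fixed 3-token overhead seen in the diff below
    countPromptTokens: (prompts) => prompts.reduce((a, m) => a + countMessage(m), 0) + 3
  }

  // ...and the caller dispatches instead of switching on a type tag:
  const total = sketchOpenAI.countPromptTokens([{ role: 'user', content: 'Hello' }])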

Stats.svelte

@@ -1,44 +1,18 @@
 <script context="module" lang="ts">
-  import { countTokens, getDeliminator, getLeadPrompt, getModelDetail, getRoleEnd, getRoleTag, getStartSequence } from './Models.svelte'
+  import { getModelDetail } from './Models.svelte'
   import type { Chat, Message, Model, Usage } from './Types.svelte'
 
   export const getPrice = (tokens: Usage, model: Model): number => {
     const t = getModelDetail(model)
-    return ((tokens.prompt_tokens * t.prompt) + (tokens.completion_tokens * t.completion))
+    return ((tokens.prompt_tokens * (t.prompt || 0)) + (tokens.completion_tokens * (t.completion || 0)))
   }
 
   export const countPromptTokens = (prompts:Message[], model:Model, chat: Chat):number => {
-    const detail = getModelDetail(model)
-    const count = prompts.reduce((a, m) => {
-      a += countMessageTokens(m, model, chat)
-      return a
-    }, 0)
-    switch (detail.type) {
-      case 'Petals':
-        return count + countTokens(model, getStartSequence(chat)) + countTokens(model, getLeadPrompt(chat))
-      case 'OpenAIChat':
-      default:
-        // Not sure how OpenAI formats it, but this seems to get close to the right counts.
-        // Would be nice to know. This works for gpt-3.5. gpt-4 could be different.
-        // Complete stab in the dark here -- update if you know where all the extra tokens really come from.
-        return count + 3 // Always seems to be message counts + 3
-    }
+    return getModelDetail(model).countPromptTokens(prompts, model, chat)
   }
 
   export const countMessageTokens = (message:Message, model:Model, chat: Chat):number => {
-    const detail = getModelDetail(model)
-    const delim = getDeliminator(chat)
-    switch (detail.type) {
-      case 'Petals':
-        return countTokens(model, getRoleTag(message.role, model, chat) + ': ' +
-          message.content + getRoleEnd(message.role, model, chat) + (delim || '###'))
-      case 'OpenAIChat':
-      default:
-        // Not sure how OpenAI formats it, but this seems to get close to the right counts.
-        // Would be nice to know. This works for gpt-3.5. gpt-4 could be different.
-        // Complete stab in the dark here -- update if you know where all the extra tokens really come from.
-        return countTokens(model, '## ' + message.role + ' ##:\r\n\r\n' + message.content + '\r\n\r\n\r\n')
-    }
+    return getModelDetail(model).countMessageTokens(message, model, chat)
   }
 
   export const getModelMaxTokens = (model:Model):number => {
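
A side note on the getPrice change above: prompt and completion prices are optional on ModelDetail, and multiplying a number by undefined yields NaN, so the new (t.prompt || 0) and (t.completion || 0) guards make unpriced models cost zero instead of poisoning the total. Illustrative values only:

  const promptPrice: number | undefined = undefined // a model with no pricing set
  const promptTokens = 100
  // old expression: promptTokens * promptPrice -> NaN when the price is missing
  const cost = promptTokens * (promptPrice || 0) // new expression: 0, not NaN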

Types.svelte

@@ -281,13 +281,15 @@ export type ModelDetail = {
   leadPrompt?: string,
   prompt?: number;
   completion?: number;
-  max?: number;
+  max: number;
   opt?: Record<string, any>;
   preFillMerge?: (existingContent:string, newContent:string)=>string;
   enabled?: boolean;
   hide?: boolean;
   check: (modelDetail: ModelDetail) => Promise<void>;
   getTokens: (val: string) => number[];
+  countPromptTokens: (prompts:Message[], model:Model, chat: Chat) => number;
+  countMessageTokens: (message:Message, model:Model, chat: Chat) => number;
   getEndpoint: (model: Model) => string;
   help: string;
   hideSetting: (chatId: number, setting: ChatSetting) => boolean;
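
With countPromptTokens and countMessageTokens added as required members, and max no longer optional, the compiler now forces every provider to supply its own counters and a context size. A hypothetical conforming entry with placeholder values, following the spread-plus-cast style visible in the provider hunks below:

  const exampleModel = {
    ...chatModelBase,     // base object supplies countPromptTokens, countMessageTokens, check, etc.
    prompt: 0.0000015,    // placeholder per-token prices, not real data
    completion: 0.000002,
    max: 4096             // context window: now a required field
  } as ModelDetail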

OpenAI model detail module

@@ -1,7 +1,9 @@
 <script context="module" lang="ts">
   import { getApiBase, getEndpointCompletions, getEndpointGenerations } from '../../ApiUtil.svelte'
+  import { countTokens } from '../../Models.svelte'
+  import { countMessageTokens } from '../../Stats.svelte'
   import { globalStorage } from '../../Storage.svelte'
-  import type { ModelDetail } from '../../Types.svelte'
+  import type { Chat, Message, Model, ModelDetail } from '../../Types.svelte'
   import { chatRequest, imageRequest } from './request.svelte'
   import { checkModel } from './util.svelte'
   import { encode } from 'gpt-tokenizer'

@@ -38,7 +40,19 @@ const chatModelBase = {
   check: checkModel,
   getTokens: (value) => encode(value),
   getEndpoint: (model) => get(globalStorage).openAICompletionEndpoint || (getApiBase() + getEndpointCompletions()),
-  hideSetting: (chatId, setting) => !!hiddenSettings[setting.key]
+  hideSetting: (chatId, setting) => !!hiddenSettings[setting.key],
+  countMessageTokens: (message:Message, model:Model, chat: Chat) => {
+    return countTokens(model, '## ' + message.role + ' ##:\r\n\r\n' + message.content + '\r\n\r\n\r\n')
+  },
+  countPromptTokens: (prompts:Message[], model:Model, chat: Chat):number => {
+    // Not sure how OpenAI formats it, but this seems to get close to the right counts.
+    // Would be nice to know. This works for gpt-3.5. gpt-4 could be different.
+    // Complete stab in the dark here -- update if you know where all the extra tokens really come from.
+    return prompts.reduce((a, m) => {
+      a += countMessageTokens(m, model, chat)
+      return a
+    }, 0) + 3 // Always seems to be message counts + 3
+  }
 } as ModelDetail

 // Reference: https://openai.com/pricing#language-models
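
After this hunk, the OpenAI counting logic lives next to the rest of the provider's configuration, and Stats.svelte reaches it through the detail object. A sketch of the new call path (the model id in the comment is an example, not taken from the diff):

  const estimate = (prompts: Message[], model: Model, chat: Chat): number =>
    getModelDetail(model).countPromptTokens(prompts, model, chat)
  // e.g. estimate(messages, 'gpt-3.5-turbo', chat) returns the per-message sum + 3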

Petals model detail module

@@ -1,7 +1,9 @@
 <script context="module" lang="ts">
   import { getPetalsBase, getPetalsWebsocket } from '../../ApiUtil.svelte'
+  import { countTokens, getDeliminator, getLeadPrompt, getRoleEnd, getRoleTag, getStartSequence } from '../../Models.svelte'
+  import { countMessageTokens } from '../../Stats.svelte'
   import { globalStorage } from '../../Storage.svelte'
-  import type { ModelDetail } from '../../Types.svelte'
+  import type { Chat, Message, Model, ModelDetail } from '../../Types.svelte'
   import { chatRequest } from './request.svelte'
   import { checkModel } from './util.svelte'
   import llamaTokenizer from 'llama-tokenizer-js'

@@ -33,7 +35,18 @@ const chatModelBase = {
   request: chatRequest,
   getEndpoint: (model) => get(globalStorage).pedalsEndpoint || (getPetalsBase() + getPetalsWebsocket()),
   getTokens: (value) => llamaTokenizer.encode(value),
-  hideSetting: (chatId, setting) => !!hideSettings[setting.key]
+  hideSetting: (chatId, setting) => !!hideSettings[setting.key],
+  countMessageTokens: (message:Message, model:Model, chat: Chat):number => {
+    const delim = getDeliminator(chat)
+    return countTokens(model, getRoleTag(message.role, model, chat) + ': ' +
+      message.content + getRoleEnd(message.role, model, chat) + (delim || '###'))
+  },
+  countPromptTokens: (prompts:Message[], model:Model, chat: Chat):number => {
+    return prompts.reduce((a, m) => {
+      a += countMessageTokens(m, model, chat)
+      return a
+    }, 0) + countTokens(model, getStartSequence(chat)) + countTokens(model, getLeadPrompt(chat))
+  }
 } as ModelDetail

 export const chatModels : Record<string, ModelDetail> = {
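
The Petals counter serializes each message as a role tag, the content, a role end marker, and a delimiter, then tokenizes that string with llama-tokenizer-js; the prompt total additionally counts the chat's start sequence and lead prompt. Roughly, with example marker values (the real ones come from the chat settings):

  const roleTag = 'user'  // stand-in for getRoleTag(message.role, model, chat)
  const roleEnd = ''      // stand-in for getRoleEnd(...)
  const delim = '###'     // the fallback used above when the chat defines none
  const serialized = roleTag + ': ' + 'Hello there' + roleEnd + delim
  // countMessageTokens above is countTokens(model, serialized)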