Move token counting to model detail.
This commit is contained in:
		
							parent
							
								
									91885384a1
								
							
						
					
					
						commit
						a08d8bcd54
					
				| 
						 | 
					@ -1,44 +1,18 @@
 | 
				
			||||||
<script context="module" lang="ts">
 | 
					<script context="module" lang="ts">
 | 
				
			||||||
  import { countTokens, getDeliminator, getLeadPrompt, getModelDetail, getRoleEnd, getRoleTag, getStartSequence } from './Models.svelte'
 | 
					  import { getModelDetail } from './Models.svelte'
 | 
				
			||||||
  import type { Chat, Message, Model, Usage } from './Types.svelte'
 | 
					  import type { Chat, Message, Model, Usage } from './Types.svelte'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  export const getPrice = (tokens: Usage, model: Model): number => {
 | 
					  export const getPrice = (tokens: Usage, model: Model): number => {
 | 
				
			||||||
    const t = getModelDetail(model)
 | 
					    const t = getModelDetail(model)
 | 
				
			||||||
    return ((tokens.prompt_tokens * t.prompt) + (tokens.completion_tokens * t.completion))
 | 
					    return ((tokens.prompt_tokens * (t.prompt || 0)) + (tokens.completion_tokens * (t.completion || 0)))
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  export const countPromptTokens = (prompts:Message[], model:Model, chat: Chat):number => {
 | 
					  export const countPromptTokens = (prompts:Message[], model:Model, chat: Chat):number => {
 | 
				
			||||||
    const detail = getModelDetail(model)
 | 
					    return getModelDetail(model).countPromptTokens(prompts, model, chat)
 | 
				
			||||||
    const count = prompts.reduce((a, m) => {
 | 
					 | 
				
			||||||
      a += countMessageTokens(m, model, chat)
 | 
					 | 
				
			||||||
      return a
 | 
					 | 
				
			||||||
    }, 0)
 | 
					 | 
				
			||||||
    switch (detail.type) {
 | 
					 | 
				
			||||||
      case 'Petals':
 | 
					 | 
				
			||||||
        return count + countTokens(model, getStartSequence(chat)) + countTokens(model, getLeadPrompt(chat))
 | 
					 | 
				
			||||||
      case 'OpenAIChat':
 | 
					 | 
				
			||||||
      default:
 | 
					 | 
				
			||||||
        // Not sure how OpenAI formats it, but this seems to get close to the right counts.
 | 
					 | 
				
			||||||
        // Would be nice to know. This works for gpt-3.5.  gpt-4 could be different.
 | 
					 | 
				
			||||||
        // Complete stab in the dark here -- update if you know where all the extra tokens really come from.
 | 
					 | 
				
			||||||
        return count + 3 // Always seems to be message counts + 3
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  export const countMessageTokens = (message:Message, model:Model, chat: Chat):number => {
 | 
					  export const countMessageTokens = (message:Message, model:Model, chat: Chat):number => {
 | 
				
			||||||
    const detail = getModelDetail(model)
 | 
					    return getModelDetail(model).countMessageTokens(message, model, chat)
 | 
				
			||||||
    const delim = getDeliminator(chat)
 | 
					 | 
				
			||||||
    switch (detail.type) {
 | 
					 | 
				
			||||||
      case 'Petals':
 | 
					 | 
				
			||||||
        return countTokens(model, getRoleTag(message.role, model, chat) + ': ' +
 | 
					 | 
				
			||||||
        message.content + getRoleEnd(message.role, model, chat) + (delim || '###'))
 | 
					 | 
				
			||||||
      case 'OpenAIChat':
 | 
					 | 
				
			||||||
      default:
 | 
					 | 
				
			||||||
        // Not sure how OpenAI formats it, but this seems to get close to the right counts.
 | 
					 | 
				
			||||||
        // Would be nice to know. This works for gpt-3.5.  gpt-4 could be different.
 | 
					 | 
				
			||||||
        // Complete stab in the dark here -- update if you know where all the extra tokens really come from.
 | 
					 | 
				
			||||||
        return countTokens(model, '## ' + message.role + ' ##:\r\n\r\n' + message.content + '\r\n\r\n\r\n')
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  export const getModelMaxTokens = (model:Model):number => {
 | 
					  export const getModelMaxTokens = (model:Model):number => {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -281,13 +281,15 @@ export type ModelDetail = {
 | 
				
			||||||
    leadPrompt?: string,
 | 
					    leadPrompt?: string,
 | 
				
			||||||
    prompt?: number;
 | 
					    prompt?: number;
 | 
				
			||||||
    completion?: number;
 | 
					    completion?: number;
 | 
				
			||||||
    max?: number;
 | 
					    max: number;
 | 
				
			||||||
    opt?: Record<string, any>;
 | 
					    opt?: Record<string, any>;
 | 
				
			||||||
    preFillMerge?: (existingContent:string, newContent:string)=>string;
 | 
					    preFillMerge?: (existingContent:string, newContent:string)=>string;
 | 
				
			||||||
    enabled?: boolean;
 | 
					    enabled?: boolean;
 | 
				
			||||||
    hide?: boolean;
 | 
					    hide?: boolean;
 | 
				
			||||||
    check: (modelDetail: ModelDetail) => Promise<void>;
 | 
					    check: (modelDetail: ModelDetail) => Promise<void>;
 | 
				
			||||||
    getTokens: (val: string) => number[];
 | 
					    getTokens: (val: string) => number[];
 | 
				
			||||||
 | 
					    countPromptTokens: (prompts:Message[], model:Model, chat: Chat) => number;
 | 
				
			||||||
 | 
					    countMessageTokens: (message:Message, model:Model, chat: Chat) => number;
 | 
				
			||||||
    getEndpoint: (model: Model) => string;
 | 
					    getEndpoint: (model: Model) => string;
 | 
				
			||||||
    help: string;
 | 
					    help: string;
 | 
				
			||||||
    hideSetting: (chatId: number, setting: ChatSetting) => boolean;
 | 
					    hideSetting: (chatId: number, setting: ChatSetting) => boolean;
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,7 +1,9 @@
 | 
				
			||||||
<script context="module" lang="ts">
 | 
					<script context="module" lang="ts">
 | 
				
			||||||
    import { getApiBase, getEndpointCompletions, getEndpointGenerations } from '../../ApiUtil.svelte'
 | 
					    import { getApiBase, getEndpointCompletions, getEndpointGenerations } from '../../ApiUtil.svelte'
 | 
				
			||||||
 | 
					    import { countTokens } from '../../Models.svelte'
 | 
				
			||||||
 | 
					    import { countMessageTokens } from '../../Stats.svelte'
 | 
				
			||||||
    import { globalStorage } from '../../Storage.svelte'
 | 
					    import { globalStorage } from '../../Storage.svelte'
 | 
				
			||||||
    import type { ModelDetail } from '../../Types.svelte'
 | 
					    import type { Chat, Message, Model, ModelDetail } from '../../Types.svelte'
 | 
				
			||||||
    import { chatRequest, imageRequest } from './request.svelte'
 | 
					    import { chatRequest, imageRequest } from './request.svelte'
 | 
				
			||||||
    import { checkModel } from './util.svelte'
 | 
					    import { checkModel } from './util.svelte'
 | 
				
			||||||
    import { encode } from 'gpt-tokenizer'
 | 
					    import { encode } from 'gpt-tokenizer'
 | 
				
			||||||
| 
						 | 
					@ -38,7 +40,19 @@ const chatModelBase = {
 | 
				
			||||||
  check: checkModel,
 | 
					  check: checkModel,
 | 
				
			||||||
  getTokens: (value) => encode(value),
 | 
					  getTokens: (value) => encode(value),
 | 
				
			||||||
  getEndpoint: (model) => get(globalStorage).openAICompletionEndpoint || (getApiBase() + getEndpointCompletions()),
 | 
					  getEndpoint: (model) => get(globalStorage).openAICompletionEndpoint || (getApiBase() + getEndpointCompletions()),
 | 
				
			||||||
  hideSetting: (chatId, setting) => !!hiddenSettings[setting.key]
 | 
					  hideSetting: (chatId, setting) => !!hiddenSettings[setting.key],
 | 
				
			||||||
 | 
					  countMessageTokens: (message:Message, model:Model, chat: Chat) => {
 | 
				
			||||||
 | 
					        return countTokens(model, '## ' + message.role + ' ##:\r\n\r\n' + message.content + '\r\n\r\n\r\n')
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  countPromptTokens: (prompts:Message[], model:Model, chat: Chat):number => {
 | 
				
			||||||
 | 
					        // Not sure how OpenAI formats it, but this seems to get close to the right counts.
 | 
				
			||||||
 | 
					        // Would be nice to know. This works for gpt-3.5.  gpt-4 could be different.
 | 
				
			||||||
 | 
					        // Complete stab in the dark here -- update if you know where all the extra tokens really come from.
 | 
				
			||||||
 | 
					        return prompts.reduce((a, m) => {
 | 
				
			||||||
 | 
					          a += countMessageTokens(m, model, chat)
 | 
				
			||||||
 | 
					          return a
 | 
				
			||||||
 | 
					        }, 0) + 3 // Always seems to be message counts + 3
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
} as ModelDetail
 | 
					} as ModelDetail
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Reference: https://openai.com/pricing#language-models
 | 
					// Reference: https://openai.com/pricing#language-models
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,7 +1,9 @@
 | 
				
			||||||
<script context="module" lang="ts">
 | 
					<script context="module" lang="ts">
 | 
				
			||||||
    import { getPetalsBase, getPetalsWebsocket } from '../../ApiUtil.svelte'
 | 
					    import { getPetalsBase, getPetalsWebsocket } from '../../ApiUtil.svelte'
 | 
				
			||||||
 | 
					    import { countTokens, getDeliminator, getLeadPrompt, getRoleEnd, getRoleTag, getStartSequence } from '../../Models.svelte'
 | 
				
			||||||
 | 
					    import { countMessageTokens } from '../../Stats.svelte'
 | 
				
			||||||
    import { globalStorage } from '../../Storage.svelte'
 | 
					    import { globalStorage } from '../../Storage.svelte'
 | 
				
			||||||
    import type { ModelDetail } from '../../Types.svelte'
 | 
					    import type { Chat, Message, Model, ModelDetail } from '../../Types.svelte'
 | 
				
			||||||
    import { chatRequest } from './request.svelte'
 | 
					    import { chatRequest } from './request.svelte'
 | 
				
			||||||
    import { checkModel } from './util.svelte'
 | 
					    import { checkModel } from './util.svelte'
 | 
				
			||||||
    import llamaTokenizer from 'llama-tokenizer-js'
 | 
					    import llamaTokenizer from 'llama-tokenizer-js'
 | 
				
			||||||
| 
						 | 
					@ -33,7 +35,18 @@ const chatModelBase = {
 | 
				
			||||||
  request: chatRequest,
 | 
					  request: chatRequest,
 | 
				
			||||||
  getEndpoint: (model) => get(globalStorage).pedalsEndpoint || (getPetalsBase() + getPetalsWebsocket()),
 | 
					  getEndpoint: (model) => get(globalStorage).pedalsEndpoint || (getPetalsBase() + getPetalsWebsocket()),
 | 
				
			||||||
  getTokens: (value) => llamaTokenizer.encode(value),
 | 
					  getTokens: (value) => llamaTokenizer.encode(value),
 | 
				
			||||||
  hideSetting: (chatId, setting) => !!hideSettings[setting.key]
 | 
					  hideSetting: (chatId, setting) => !!hideSettings[setting.key],
 | 
				
			||||||
 | 
					  countMessageTokens: (message:Message, model:Model, chat: Chat):number => {
 | 
				
			||||||
 | 
					        const delim = getDeliminator(chat)
 | 
				
			||||||
 | 
					        return countTokens(model, getRoleTag(message.role, model, chat) + ': ' +
 | 
				
			||||||
 | 
					        message.content + getRoleEnd(message.role, model, chat) + (delim || '###'))
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  countPromptTokens: (prompts:Message[], model:Model, chat: Chat):number => {
 | 
				
			||||||
 | 
					        return prompts.reduce((a, m) => {
 | 
				
			||||||
 | 
					          a += countMessageTokens(m, model, chat)
 | 
				
			||||||
 | 
					          return a
 | 
				
			||||||
 | 
					        }, 0) + countTokens(model, getStartSequence(chat)) + countTokens(model, getLeadPrompt(chat))
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
} as ModelDetail
 | 
					} as ModelDetail
 | 
				
			||||||
 | 
					
 | 
				
			||||||
export const chatModels : Record<string, ModelDetail> = {
 | 
					export const chatModels : Record<string, ModelDetail> = {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue