Get temp and top_p working for Petals

This commit is contained in:
parent 6d35a46d50
commit 15dcd27e8f
@@ -53,6 +53,14 @@ export const runPetalsCompletionRequest = async (
       throw err
     }
   }
   const rMessages = request.messages || [] as Message[]
+  // make sure top_p and temperature are set the way we need
+  let temperature = request.temperature || 0
+  if (isNaN(temperature as any) || temperature === 1) temperature = 1
+  if (temperature === 0) temperature = 0.0001
+  let topP = request.top_p
+  if (isNaN(topP as any) || topP === 1) topP = 1
+  if (topP === 0) topP = 0.0001
+  // build the message array
   const inputArray = (rMessages).reduce((a, m) => {
     const c = getRoleTag(m.role, model, chatRequest.chat) + m.content
     a.push(c)
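These guards keep the sampler from receiving values it cannot use: a temperature of exactly 0 would mean dividing logits by zero, and a top_p of 0 would keep no tokens, so both are nudged to 0.0001, while NaN falls back to the neutral 1. A minimal sketch of the same logic as a standalone helper (the helper name is hypothetical; the missing-value behavior follows the top_p branch above, where NaN resolves to 1):

// Hypothetical helper mirroring the guards above; not part of the commit.
const sanitizeSamplingParam = (value: unknown): number => {
  const n = Number(value)
  if (isNaN(n) || n === 1) return 1 // missing/NaN, or already the neutral 1
  if (n === 0) return 0.0001 // exactly 0 breaks sampling, so nudge it upward
  return n
}

sanitizeSamplingParam(undefined) // 1
sanitizeSamplingParam(0) // 0.0001
sanitizeSamplingParam(0.9) // 0.9

Note that the temperature branch first applies "|| 0", so an unset temperature ends up at 0.0001 rather than 1; only top_p falls through the NaN check to 1.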
@@ -65,11 +73,11 @@ export const runPetalsCompletionRequest = async (
   const petalsRequest = {
     type: 'generate',
     inputs: inputArray.join(stopSequence),
-    max_new_tokens: 3, // wait for up to 3 tokens before displaying
+    max_new_tokens: 1, // wait for up to 1 token before displaying
     stop_sequence: stopSequence,
-    doSample: 1,
-    temperature: request.temperature || 0,
-    top_p: request.top_p || 0,
+    do_sample: 1, // enable top_p and the like
+    temperature,
+    top_p: topP,
     extra_stop_sequences: stopSequencesC
   }
   ws.send(JSON.stringify(petalsRequest))
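With the sanitized values in place, the single generate message sent over the websocket now carries working sampling settings, and the misspelled doSample key no longer silently disables them. An illustrative payload, with made-up values and an assumed '###' stop sequence; the field set is exactly the request object above:

// Illustrative only: the shape of JSON.stringify(petalsRequest).
ws.send(JSON.stringify({
  type: 'generate',
  inputs: '...role-tagged messages joined by the stop sequence...',
  max_new_tokens: 1, // stream each token back as soon as it is ready
  stop_sequence: '###',
  do_sample: 1, // without this flag the server ignores temperature and top_p
  temperature: 0.7,
  top_p: 0.9,
  extra_stop_sequences: ['###']
}))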
@@ -55,6 +55,14 @@ export const getExcludeFromProfile = () => {
   return excludeFromProfile
 }
 
+const isNotOpenAI = (chatId) => {
+  return getModelDetail(getChatSettings(chatId).model).type !== 'OpenAIChat'
+}
+
+const isNotPetals = (chatId) => {
+  return getModelDetail(getChatSettings(chatId).model).type !== 'Petals'
+}
+
 const gptDefaults = {
   model: defaultModel,
   messages: [],
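Hoisting the two predicates above gptDefaults makes them available to every definition that follows, including the header function below. Both fit the ChatSetting hide hook, so a hypothetical OpenAI-only entry could remove itself from Petals chats:

// Hypothetical setting entry, not from this commit: hide an OpenAI-only
// knob whenever the chat's current model is not an OpenAI one.
const presencePenaltySetting = {
  key: 'presence_penalty',
  name: 'Presence Penalty',
  hide: isNotOpenAI // satisfies hide?: (chatId: number) => boolean
}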
@@ -406,7 +414,13 @@ const modelSetting: ChatSetting & SettingSelect = {
   key: 'model',
   name: 'Model',
   title: 'The model to use - GPT-3.5 is cheaper, but GPT-4 is more powerful.',
-  header: 'Below are the settings that OpenAI allows to be changed for the API calls. See the <a target="_blank" href="https://platform.openai.com/docs/api-reference/chat/create">OpenAI API docs</a> for more details.',
+  header: (chatId) => {
+    if (isNotOpenAI(chatId)) {
+      return 'Below are the settings that can be changed for the API calls. See <a target="_blank" href="https://platform.openai.com/docs/api-reference/chat/create">this overview</a> to start, though not all settings translate to Petals.'
+    } else {
+      return 'Below are the settings that OpenAI allows to be changed for the API calls. See the <a target="_blank" href="https://platform.openai.com/docs/api-reference/chat/create">OpenAI API docs</a> for more details.'
+    }
+  },
   headerClass: 'is-warning',
   options: [],
   type: 'select',
@@ -414,14 +428,6 @@
   afterChange: (chatId, setting) => true // refresh settings
 }
 
-const isNotOpenAI = (chatId) => {
-  return getModelDetail(getChatSettings(chatId).model).type !== 'OpenAIChat'
-}
-
-const isNotPetals = (chatId) => {
-  return getModelDetail(getChatSettings(chatId).model).type !== 'Petals'
-}
-
 const chatSettingsList: ChatSetting[] = [
   profileSetting,
   ...systemPromptSettings,
@@ -448,7 +454,7 @@ const chatSettingsList: ChatSetting[] = [
   },
   {
     key: 'top_p',
-    name: 'Nucleus Sampling',
+    name: 'Nucleus Sampling (Top-p)',
     title: 'An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n' +
       '\n' +
       'We generally recommend altering this or temperature but not both',
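The clearer label ties the top_p key to its common name. For intuition: nucleus sampling sorts tokens by probability and keeps the smallest prefix whose cumulative mass reaches p, then samples only from that set. A toy sketch, not code from this repo:

// Toy illustration of nucleus (top-p) filtering over a distribution.
const topPFilter = (probs: number[], p: number): number[] => {
  const ranked = probs
    .map((prob, token) => ({ token, prob }))
    .sort((a, b) => b.prob - a.prob) // most likely tokens first
  const kept: number[] = []
  let mass = 0
  for (const { token, prob } of ranked) {
    kept.push(token)
    mass += prob
    if (mass >= p) break // stop once the nucleus covers p of the mass
  }
  return kept // token ids still eligible for sampling
}

topPFilter([0.5, 0.3, 0.1, 0.1], 0.8) // [0, 1]: the top 80% of the mass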
@@ -259,8 +259,8 @@ export type ChatSetting = {
   title: string;
   forceApi?: boolean; // force in api requests, even if set to default
   hidden?: boolean; // Hide from setting menus
-  header?: string;
-  headerClass?: string;
+  header?: string | ValueFn;
+  headerClass?: string | ValueFn;
   placeholder?: string | ValueFn;
   hide?: (chatId:number) => boolean;
   apiTransform?: (chatId:number, setting:ChatSetting, value:any) => any;
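Widening header and headerClass to string | ValueFn is what lets modelSetting supply the per-chat header function added earlier in this commit. Rendering code then needs a single dispatch; a sketch, assuming ValueFn is (chatId: number) => string as the existing placeholder field suggests:

// Sketch of a uniform accessor for fields typed string | ValueFn.
const resolveValue = (chatId: number, value?: string | ValueFn): string =>
  typeof value === 'function' ? value(chatId) : (value || '')

// resolveValue(chatId, setting.header) handles both the old plain-string
// headers and the new per-chat function on modelSetting.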