
How to create a custom chat model class

This notebook goes over how to create a custom chat model wrapper, in case you want to use your own chat model or a different wrapper than one that is directly supported in LangChain.

There are a few things a chat model must implement after extending the SimpleChatModel class:

  • A _call method that takes in a list of messages and call options (which include things like stop sequences), and returns a string.
  • An _llmType method that returns a string, used only for logging purposes.

You can also implement the following optional method:

  • A _streamResponseChunks method that returns an AsyncGenerator and yields ChatGenerationChunks. This allows the LLM to support streaming outputs.

Let's implement a very simple custom chat model that just echoes back the first n characters of the input.

import {
  SimpleChatModel,
  type BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import { AIMessageChunk, type BaseMessage } from "@langchain/core/messages";
import { ChatGenerationChunk } from "@langchain/core/outputs";

export interface CustomChatModelInput extends BaseChatModelParams {
  n: number;
}

export class CustomChatModel extends SimpleChatModel {
  n: number;

  constructor(fields: CustomChatModelInput) {
    super(fields);
    this.n = fields.n;
  }

  _llmType() {
    return "custom";
  }

  async _call(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<string> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
    // await subRunnable.invoke(params, runManager?.getChild());
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    return messages[0].content.slice(0, this.n);
  }

  async *_streamResponseChunks(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): AsyncGenerator<ChatGenerationChunk> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
    // await subRunnable.invoke(params, runManager?.getChild());
    for (const letter of messages[0].content.slice(0, this.n)) {
      yield new ChatGenerationChunk({
        message: new AIMessageChunk({
          content: letter,
        }),
        text: letter,
      });
      // Trigger the appropriate callback for new chunks
      await runManager?.handleLLMNewToken(letter);
    }
  }
}

We can now use this as any other chat model:

const chatModel = new CustomChatModel({ n: 4 });

await chatModel.invoke([["human", "I am an LLM"]]);
AIMessage {
  content: 'I am',
  additional_kwargs: {}
}
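
Because the custom model implements LangChain's standard Runnable interface, other methods such as batch also work. Here's a minimal sketch reusing the chatModel instance defined above:

const results = await chatModel.batch([
  [["human", "I am an LLM"]],
  [["human", "Hello world"]],
]);

// Each result is an AIMessage truncated to `n` characters
console.log(results.map((message) => message.content));
// [ 'I am', 'Hell' ]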

It also supports streaming:

const stream = await chatModel.stream([["human", "I am an LLM"]]);

for await (const chunk of stream) {
  console.log(chunk);
}
AIMessageChunk {
  content: 'I',
  additional_kwargs: {}
}
AIMessageChunk {
  content: ' ',
  additional_kwargs: {}
}
AIMessageChunk {
  content: 'a',
  additional_kwargs: {}
}
AIMessageChunk {
  content: 'm',
  additional_kwargs: {}
}
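
Because it behaves like any other chat model, it can also be composed with other runnables. Here's a minimal sketch that pipes a prompt template from @langchain/core/prompts into the custom model:

import { ChatPromptTemplate } from "@langchain/core/prompts";

const prompt = ChatPromptTemplate.fromMessages([["human", "{input}"]]);

// The formatted messages from the prompt are passed directly to the custom model
const chain = prompt.pipe(chatModel);

await chain.invoke({ input: "I am an LLM" });
// AIMessage { content: 'I am', additional_kwargs: {} }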

Richer outputs

If you want to take advantage of LangChain's callback system for functionality like token tracking, you can extend the BaseChatModel class and implement the lower level _generate method. It also takes a list of BaseMessages as input, but requires you to construct and return a ChatGeneration object that permits additional metadata. Here's an example:

import { AIMessage, BaseMessage } from "@langchain/core/messages";
import { ChatResult } from "@langchain/core/outputs";
import {
  BaseChatModel,
  BaseChatModelCallOptions,
  BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";

export interface AdvancedCustomChatModelOptions
  extends BaseChatModelCallOptions {}

export interface AdvancedCustomChatModelParams extends BaseChatModelParams {
  n: number;
}

export class AdvancedCustomChatModel extends BaseChatModel<AdvancedCustomChatModelOptions> {
  n: number;

  static lc_name(): string {
    return "AdvancedCustomChatModel";
  }

  constructor(fields: AdvancedCustomChatModelParams) {
    super(fields);
    this.n = fields.n;
  }

  async _generate(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<ChatResult> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
    // await subRunnable.invoke(params, runManager?.getChild());
    const content = messages[0].content.slice(0, this.n);
    const tokenUsage = {
      usedTokens: this.n,
    };
    return {
      generations: [{ message: new AIMessage({ content }), text: content }],
      llmOutput: { tokenUsage },
    };
  }

  _llmType(): string {
    return "advanced_custom_chat_model";
  }
}

This will pass the additional returned information to callback events and to the streamEvents method:

const chatModel = new AdvancedCustomChatModel({ n: 4 });

const eventStream = await chatModel.streamEvents([["human", "I am an LLM"]], {
  version: "v1",
});

for await (const event of eventStream) {
  if (event.event === "on_llm_end") {
    console.log(JSON.stringify(event, null, 2));
  }
}
{
  "event": "on_llm_end",
  "name": "AdvancedCustomChatModel",
  "run_id": "b500b98d-bee5-4805-9b92-532a491f5c70",
  "tags": [],
  "metadata": {},
  "data": {
    "output": {
      "generations": [
        [
          {
            "message": {
              "lc": 1,
              "type": "constructor",
              "id": [
                "langchain_core",
                "messages",
                "AIMessage"
              ],
              "kwargs": {
                "content": "I am",
                "additional_kwargs": {}
              }
            },
            "text": "I am"
          }
        ]
      ],
      "llmOutput": {
        "tokenUsage": {
          "usedTokens": 4
        }
      }
    }
  }
}
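
The llmOutput returned from _generate is also surfaced to callback handlers. Here's a minimal sketch that reads the custom tokenUsage field from a handleLLMEnd handler passed at call time:

await chatModel.invoke([["human", "I am an LLM"]], {
  callbacks: [
    {
      handleLLMEnd(output) {
        // `llmOutput` carries whatever `_generate` returned alongside the generations
        console.log("Token usage:", output.llmOutput?.tokenUsage);
      },
    },
  ],
});
// Token usage: { usedTokens: 4 }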

Tracing (advanced)

If you are implementing a custom chat model and want to use it with a tracing service like LangSmith, you can automatically log params used for a given invocation by implementing the invocationParams() method on the model.

This method is purely optional, but anything it returns will be logged as metadata for the trace.

Here's one pattern you might use:

import { AIMessage, BaseMessage } from "@langchain/core/messages";
import { ChatResult } from "@langchain/core/outputs";
import {
  BaseChatModel,
  BaseChatModelCallOptions,
  BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";

export interface CustomChatModelOptions extends BaseChatModelCallOptions {
  // Some required or optional inner args
  tools: Record<string, any>[];
}

export interface CustomChatModelParams extends BaseChatModelParams {
  temperature: number;
}

export class CustomChatModel extends BaseChatModel<CustomChatModelOptions> {
  temperature: number;

  static lc_name(): string {
    return "CustomChatModel";
  }

  constructor(fields: CustomChatModelParams) {
    super(fields);
    this.temperature = fields.temperature;
  }

  // Anything returned in this method will be logged as metadata in the trace.
  // It is common to pass it any options used to invoke the function.
  invocationParams(options?: this["ParsedCallOptions"]) {
    return {
      tools: options?.tools,
      temperature: this.temperature,
    };
  }

  async _generate(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<ChatResult> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    const additionalParams = this.invocationParams(options);
    // `someAPIRequest` stands in for a call to your model's underlying API
    const content = await someAPIRequest(messages, additionalParams);
    return {
      generations: [{ message: new AIMessage({ content }), text: content }],
      llmOutput: {},
    };
  }

  _llmType(): string {
    return "advanced_custom_chat_model";
  }
}
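
Assuming someAPIRequest is implemented, you would pass the inner call options (like tools) as the second argument to invoke, and invocationParams would pick them up and log them with the trace. The tool schema below is purely illustrative:

const model = new CustomChatModel({ temperature: 0.1 });

await model.invoke([["human", "What is 2 + 2?"]], {
  // Logged as invocation metadata via `invocationParams`
  tools: [{ name: "calculator", description: "Evaluates math expressions" }],
});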
