ONNX Runtime generate() Java API

Note: this API is in preview and is subject to change.

Overview

This document describes the Java API for ONNX Runtime GenAI.
The sections below list the main classes and their methods, each with its signature and a short description.


Install and import

The Java API is delivered by the ai.onnxruntime.genai Java package. Package publication is pending. To build the package from source, see the build from source guide.

import ai.onnxruntime.genai.*;

Model class

Constructor

Initializes a new model from the model folder at the given path.

public Model(String modelPath) throws GenAIException

createGeneratorParams

Creates a GeneratorParams instance for executing the model.

public GeneratorParams createGeneratorParams() throws GenAIException

createTokenizer

Creates a Tokenizer instance for this model.

public Tokenizer createTokenizer() throws GenAIException

generate

Generates output sequences using the provided generator parameters.

public Sequences generate(GeneratorParams generatorParams) throws GenAIException

Config class

Constructor

Initializes a new configuration object from a config path.

public Config(String configPath) throws GenAIException

clearProviders

Removes all execution providers from the configuration.

public void clearProviders() throws GenAIException

appendProvider

Appends an execution provider to the configuration.

public void appendProvider(String provider) throws GenAIException

setProviderOption

Sets the option name to value for the given provider in the configuration.

public void setProviderOption(String provider, String name, String value) throws GenAIException

overlay

Overlays a JSON string onto the configuration, overriding the matching settings loaded from the config path.

public void overlay(String json) throws GenAIException
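The JSON passed to overlay follows the layout of the model's genai_config.json file. As a minimal sketch, assuming search settings live under a top-level "search" object (as max_length does in that file), the hypothetical SearchOverlay helper below, which is not part of ai.onnxruntime.genai, builds such a string from a map of option names to numeric or boolean values. The result would then be passed to config.overlay(json). It does not escape keys, so it only suits plain option names.

```java
import java.util.Map;
import java.util.StringJoiner;

// Hypothetical helper (not part of ai.onnxruntime.genai): builds a JSON overlay
// string of the form {"search":{"name":value,...}} for Config.overlay.
final class SearchOverlay {
    private SearchOverlay() {}

    static String build(Map<String, Object> searchOptions) {
        StringJoiner fields = new StringJoiner(",", "{\"search\":{", "}}");
        for (Map.Entry<String, Object> entry : searchOptions.entrySet()) {
            Object value = entry.getValue();
            // Only finite numbers and booleans are accepted; both serialize via toString().
            boolean finiteNumber = value instanceof Number
                    && Double.isFinite(((Number) value).doubleValue());
            if (!finiteNumber && !(value instanceof Boolean)) {
                throw new IllegalArgumentException("Unsupported value for " + entry.getKey());
            }
            fields.add("\"" + entry.getKey() + "\":" + value);
        }
        return fields.toString();
    }
}
```

For example, a LinkedHashMap holding max_length=256 and then do_sample=false yields {"search":{"max_length":256,"do_sample":false}}. An insertion-ordered map only makes the output predictable; key order does not matter to a JSON consumer.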

Tokenizer class

Constructor

Initializes a tokenizer for the given model.

public Tokenizer(Model model) throws GenAIException

encode

Encodes a string into a sequence of token ids.

public Sequences encode(String string) throws GenAIException

encodeBatch

Encodes an array of strings, producing one sequence of token ids per input string.

public Sequences encodeBatch(String[] strings) throws GenAIException

decode

Decodes a sequence of token ids into text.

public String decode(int[] sequence) throws GenAIException

decodeBatch

Decodes a batch of sequences of token ids into text.

public String[] decodeBatch(Sequences sequences) throws GenAIException

createStream

Creates a TokenizerStream object for decoding tokens incrementally as they are generated.

public TokenizerStream createStream() throws GenAIException

TokenizerStream class

decode

Decodes a single token in the stream and returns the resulting text chunk. The chunk can be empty when the token does not yet complete a piece of displayable text.

public String decode(int token) throws GenAIException

GeneratorParams class

Constructor

Initializes generator parameters for the given model.

public GeneratorParams(Model model) throws GenAIException

setSearchOption (double)

Sets a numeric search option, such as max_length or temperature.

public void setSearchOption(String optionName, double value) throws GenAIException

setSearchOption (boolean)

Sets a boolean search option, such as do_sample.

public void setSearchOption(String optionName, boolean value) throws GenAIException

setInput (Sequences)

Sets the prompt(s) for model execution from a Sequences object, such as one returned by Tokenizer.encode or Tokenizer.encodeBatch.

public void setInput(Sequences sequences) throws GenAIException

setInput (int[])

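Sets the token ids of the prompt(s) for model execution. tokenIds holds batchSize prompts of sequenceLength tokens each, stored one after another in a single flat array.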

public void setInput(int[] tokenIds, int sequenceLength, int batchSize) throws GenAIException
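Because setInput takes one flat array plus its dimensions, prompts of different lengths must be padded to a common length first. The sketch below shows one way to build that layout. The BatchPacker helper is hypothetical, not part of ai.onnxruntime.genai. It pads on the right with a caller-supplied pad token id; which side to pad on and which pad id to use are model-specific and are assumptions here.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper (not part of ai.onnxruntime.genai): packs a batch of prompts
// into the flat layout used by setInput(int[] tokenIds, int sequenceLength, int batchSize).
final class BatchPacker {
    private BatchPacker() {}

    // The flat token array together with the dimensions to pass to setInput.
    static final class Packed {
        final int[] tokenIds;
        final int sequenceLength;
        final int batchSize;

        Packed(int[] tokenIds, int sequenceLength, int batchSize) {
            this.tokenIds = tokenIds;
            this.sequenceLength = sequenceLength;
            this.batchSize = batchSize;
        }
    }

    static Packed pack(List<int[]> prompts, int padTokenId) {
        int batchSize = prompts.size();
        int sequenceLength = 0;
        for (int[] prompt : prompts) {
            sequenceLength = Math.max(sequenceLength, prompt.length);
        }

        int[] flat = new int[batchSize * sequenceLength];
        for (int b = 0; b < batchSize; b++) {
            int[] prompt = prompts.get(b);
            int offset = b * sequenceLength; // row b starts here
            System.arraycopy(prompt, 0, flat, offset, prompt.length);
            // Right-pad shorter prompts (padding side is an assumption).
            Arrays.fill(flat, offset + prompt.length, offset + sequenceLength, padTokenId);
        }
        return new Packed(flat, sequenceLength, batchSize);
    }
}
```

With the real API, the result would be applied as params.setInput(packed.tokenIds, packed.sequenceLength, packed.batchSize). When the text is available, Tokenizer.encodeBatch together with setInput(Sequences) avoids manual packing.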

Generator class

Constructor

Constructs a Generator object with the given model and generator parameters.

public Generator(Model model, GeneratorParams generatorParams) throws GenAIException

isDone

Returns true once generation has finished, for example because the maximum length was reached or an end-of-sequence token was generated.

public boolean isDone()

computeLogits

Computes the logits for the next token in the sequence.

public void computeLogits() throws GenAIException

generateNextToken

Generates the next token in the sequence.

public void generateNextToken() throws GenAIException

getSequence

Retrieves a sequence of token ids for the specified sequence index.

public int[] getSequence(long sequenceIndex) throws GenAIException

getLastTokenInSequence

Retrieves the last token in the sequence for the specified sequence index.

public int getLastTokenInSequence(long sequenceIndex) throws GenAIException
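The Generator methods are meant to be called in a loop until isDone() returns true: compute logits, generate the next token, then read that token back and decode it. This lets output be shown as it is produced. So the sketch runs without the native library, it uses hypothetical interfaces, TokenSource and ChunkDecoder, that mirror the documented Generator and TokenizerStream methods, plus a scripted stub, ScriptedGenerator. None of these names are part of ai.onnxruntime.genai, and the stand-ins omit the checked GenAIException that the real methods can throw.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Hypothetical stand-ins mirroring the documented Generator and TokenizerStream
// methods, so the loop can run without the native ai.onnxruntime.genai library.
interface TokenSource {
    boolean isDone();
    void computeLogits();
    void generateNextToken();
    int getLastTokenInSequence(long sequenceIndex);
}

interface ChunkDecoder {
    String decode(int token);
}

final class GenerationLoop {
    private GenerationLoop() {}

    // Runs the token-by-token loop for sequence 0 and concatenates the decoded chunks.
    static String run(TokenSource generator, ChunkDecoder stream) {
        StringBuilder out = new StringBuilder();
        while (!generator.isDone()) {
            generator.computeLogits();     // logits for the next position
            generator.generateNextToken(); // pick a token using the search options
            int token = generator.getLastTokenInSequence(0);
            out.append(stream.decode(token)); // a chunk may be empty
        }
        return out.toString();
    }
}

// Scripted stub: "generates" a fixed list of tokens, then reports that it is done.
final class ScriptedGenerator implements TokenSource {
    private final Deque<Integer> pending = new ArrayDeque<>();
    private int last = -1;

    ScriptedGenerator(int... tokens) {
        for (int t : tokens) {
            pending.addLast(t);
        }
    }

    public boolean isDone() { return pending.isEmpty(); }
    public void computeLogits() { /* a real model would run inference here */ }
    public void generateNextToken() { last = pending.removeFirst(); }
    public int getLastTokenInSequence(long sequenceIndex) { return last; }
}
```

With the real classes, the loop body is the same, called on a Generator and on the TokenizerStream returned by Tokenizer.createStream, inside code that handles GenAIException.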

Sequences class

numSequences

Gets the number of sequences in the collection.

public long numSequences()

getSequence

Gets the sequence at the specified index.

public int[] getSequence(long sequenceIndex)

Tensor class

Constructor

Constructs a Tensor with the given data, shape, and element type.

public Tensor(ByteBuffer data, long[] shape, ElementType elementType) throws GenAIException
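The Tensor constructor takes its data as a ByteBuffer alongside a shape. The minimal sketch below prepares float32 data and checks it against the shape. It assumes the buffer should be a direct ByteBuffer in native byte order, which is typical for data handed to native code (check the Tensor Javadoc), and that values are already in row-major order. The TensorData helper is hypothetical; its buffer and shape would then be passed to new Tensor(buffer, shape, elementType) with the ElementType for 32-bit floats.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Hypothetical helper (not part of ai.onnxruntime.genai): copies float data into a
// direct, native-order ByteBuffer for use with a Tensor constructor.
final class TensorData {
    private TensorData() {}

    // Number of elements implied by a shape, e.g. {2, 3} -> 6; an empty shape is a scalar.
    static long elementCount(long[] shape) {
        long count = 1;
        for (long dim : shape) {
            if (dim < 0) {
                throw new IllegalArgumentException("negative dimension: " + dim);
            }
            count = Math.multiplyExact(count, dim);
        }
        return count;
    }

    static ByteBuffer floats(float[] values, long[] shape) {
        if (elementCount(shape) != values.length) {
            throw new IllegalArgumentException("shape does not match data length");
        }
        ByteBuffer buffer = ByteBuffer.allocateDirect(values.length * Float.BYTES)
                .order(ByteOrder.nativeOrder());
        // values are assumed to be in row-major order already; copy them as-is.
        buffer.asFloatBuffer().put(values);
        return buffer; // position is still 0, so the whole buffer is readable
    }
}
```

Validating the element count before allocating catches a mismatched shape in Java rather than in native code.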

Result class

isSuccess

Returns true if the operation succeeded.

public boolean isSuccess()

getError

Gets the error message from a failed operation.

public String getError()

Utils class

setLogBool

Sets a boolean logging option.

public static void setLogBool(String name, boolean value)

setLogString

Sets a string logging option.

public static void setLogString(String name, String value)

setCurrentGpuDeviceId

Sets the current GPU device ID.

public static void setCurrentGpuDeviceId(int deviceId)

getCurrentGpuDeviceId

Gets the current GPU device ID.

public static int getCurrentGpuDeviceId()