After creating a machine learning model with Oracle Machine Learning for Python (aka OML4Py) and exporting it to ONNX (see article), I wanted to have more fun with it, this time with Helidon. I ran some tests with Spring Boot and Micronaut, but in the end I decided to write this article to show how to publish the ONNX model as a RESTful service using Helidon.
Helidon is a cloud-native, open‑source set of Java libraries for writing microservices that run on a fast web core powered by Netty.
Helidon provides an application starter to create a minimal structure for your project (very similar to the ones provided by Spring or Micronaut).
Let me show the steps in the starter wizard.
Helidon comes in two flavors: SE and MP, providing similar functionality but offering different developer experiences. This is what the documentation says about them:
I’m going with Helidon MP. It will be a bit larger in size and less performant, but it’s similar to Spring Boot, and the application will be based on modern enterprise Java standards such as Jakarta EE and MicroProfile (not really important for this scenario).
If you don’t have a preference, there’s a final note in the documentation:
If you don’t know which Helidon flavor to use — use Helidon MP.
In step 2, I’ll take the defaults, as I don’t need these specific libraries in my project. I do need the ONNX libraries, but I’ll add them by hand later.
I don’t have any preference, so I’ll go ahead with the Jakarta JSON Binding.
In the last step, the project name and package names can be customized to your preferences. For the sake of simplicity, I’ll keep the defaults.
That’s all! A zip file with the minimal project structure is generated and downloaded to your computer. Unzip it and open it with your favorite IDE. In my case, I’m using VSCode (or Codium), and this is the final folder and file structure after removing some of the auto-generated examples and creating the Java classes for ONNX inferencing.
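For reference, the layout ends up roughly like this (a sketch assuming the default com.example.myproject package, with the model file placed under src/main/resources/model/ so the classloader code shown later can find it):
helidon-onnx/
├── pom.xml
└── src/main/
    ├── java/com/example/myproject/
    │   ├── Message.java
    │   ├── OnnxRT.java
    │   └── SimpleONNXResource.java
    └── resources/
        ├── META-INF/microprofile-config.properties
        └── model/logreg_orders.onnx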
Let’s go bit by bit. First, let me show the Java ONNX libraries to add to pom.xml (more information here). As I don’t have a GPU but a regular CPU, I’ll add the standard runtime library.
<dependency>
    <groupId>com.microsoft.onnxruntime</groupId>
    <artifactId>onnxruntime</artifactId>
    <version>1.15.0</version>
</dependency>
I’m going to create a class named OnnxRT for the ONNX inferencing. It has two methods: createSession(), which creates the ONNX runtime session with basic options by loading the logreg_orders.onnx file that contains the ML model, and getScore(), which takes six float parameters and calculates the score using the ONNX runtime session.
package com.example.myproject;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.OrtSession.Result;
import ai.onnxruntime.OrtSession.SessionOptions;
import ai.onnxruntime.OrtSession.SessionOptions.OptLevel;

public class OnnxRT {

    private OrtSession session = null;
    private OrtEnvironment env = null;
    private static final Logger LOGGER = Logger.getLogger(OnnxRT.class.getName());

    public OnnxRT() {
        try {
            createSession();
        } catch (OrtException | IOException e) {
            LOGGER.info("Couldn't create onnx session: " + e.getMessage());
        }
    }

    // Loads the ONNX model from the classpath and creates the runtime session
    private void createSession() throws OrtException, IOException {
        String modelPath = "model/logreg_orders.onnx";
        byte[] data = OnnxRT.class.getClassLoader()
                                  .getResourceAsStream(modelPath)
                                  .readAllBytes();
        env = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions opts = new SessionOptions();
        opts.setOptimizationLevel(OptLevel.BASIC_OPT);
        LOGGER.info("Loading model from " + modelPath);
        session = env.createSession(data, opts);
    }

    // Runs the model with the six input features and returns the prediction as JSON
    public String getScore(float distance, float trip_duration,
                           float tmed, float prec, float wday, float items)
            throws IOException, OrtException {
        if (session == null) {
            createSession();
        }
        String inputName = session.getInputNames().iterator().next();
        OnnxTensor test = OnnxTensor.createTensor(env,
                new float[][] {{distance, trip_duration, tmed, prec, wday, items}});
        Map<String, OnnxTensor> container = new HashMap<>();
        container.put(inputName, test);
        Result output = session.run(container);
        long[] labels = (long[]) output.get(0).getValue();
        long predLabel = labels[0];
        LOGGER.info("Input values: " + distance + ", " + trip_duration + ", " + tmed
                + ", " + prec + ", " + wday + ", " + items
                + " Prediction: " + predLabel);
        return "{ \"Prediction\": " + predLabel + " }";
    }
}
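Just to sanity-check the class outside Helidon, a quick throwaway main method could look like the sketch below (hypothetical, not part of the generated project; assumed to live in the same package as OnnxRT):

// Hypothetical smoke test for OnnxRT, not part of the project
public class OnnxRTSmokeTest {
    public static void main(String[] args) throws Exception {
        OnnxRT onnx = new OnnxRT();
        // Same sample values used later in the curl example
        String result = onnx.getScore(18.52f, 39.99f, 4.1f, 0f, 2f, 2f);
        System.out.println(result); // e.g. { "Prediction": 1 }
    }
}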
Now, let me show the SimpleONNXResource class that serves the /simple-onnx requests. It’s based on the simple resource class generated by the Helidon starter. The ONNX scoring endpoint will be at /simple-onnx/score; it expects six numeric parameters as part of the GET request, for instance, something like:
http://localhost:8080/simple-onnx/score?distance=18.52&trip_duration=39.99&tmed=4.1&prec=0&wday=2&items=2
package com.example.myproject;

import java.io.IOException;
import java.util.logging.Logger;

import ai.onnxruntime.OrtException;
import jakarta.inject.Inject;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.QueryParam;
import jakarta.ws.rs.core.MediaType;
import org.eclipse.microprofile.config.inject.ConfigProperty;
import org.eclipse.microprofile.metrics.MetricUnits;
import org.eclipse.microprofile.metrics.annotation.Counted;
import org.eclipse.microprofile.metrics.annotation.Timed;

@Path("/simple-onnx")
public class SimpleONNXResource {

    private static final String PERSONALIZED_GETS_COUNTER_NAME = "personalizedGets";
    private static final String PERSONALIZED_GETS_COUNTER_DESCRIPTION = "Counts personalized GET operations";
    private static final String GETS_TIMER_NAME = "allGets";
    private static final String GETS_TIMER_DESCRIPTION = "Tracks all GET operations";
    private static final Logger LOGGER = Logger.getLogger(SimpleONNXResource.class.getName());

    private final String message;
    private OnnxRT onnxRuntime = new OnnxRT();

    @Inject
    public SimpleONNXResource(@ConfigProperty(name = "app.greeting") String message) {
        this.message = message;
    }

    /**
     * Return a worldly greeting message.
     *
     * @return {@link Message}
     */
    @GET
    @Produces(MediaType.APPLICATION_JSON)
    public Message getDefaultMessage() {
        String msg = "Welcome to ONNX runtime";
        Message message = new Message();
        message.setMessage(msg);
        return message;
    }

    /**
     * Return a scoring.
     *
     * @return {@link Message}
     */
    @Path("/score")
    @GET
    @Produces(MediaType.APPLICATION_JSON)
    @Counted(name = PERSONALIZED_GETS_COUNTER_NAME,
            absolute = true,
            description = PERSONALIZED_GETS_COUNTER_DESCRIPTION)
    @Timed(name = GETS_TIMER_NAME,
            description = GETS_TIMER_DESCRIPTION,
            unit = MetricUnits.SECONDS,
            absolute = true)
    public Message getMessage(@QueryParam("distance") Float distance,
                              @QueryParam("trip_duration") Float trip_duration,
                              @QueryParam("tmed") Float tmed,
                              @QueryParam("prec") Float prec,
                              @QueryParam("wday") Float wday,
                              @QueryParam("items") Float items) {
        LOGGER.info("Scoring request received");
        try {
            String scoring = onnxRuntime.getScore(distance, trip_duration, tmed, prec, wday, items);
            Message message = new Message();
            message.setMessage(scoring);
            return message;
        } catch (IOException | OrtException e) {
            LOGGER.info("Processing exception " + e.getMessage());
            Message message = new Message();
            message.setMessage("Processing error, check the logs");
            return message;
        }
    }
}
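The Message class isn’t shown here because it comes straight from the Helidon starter; as a reference, it’s essentially a plain JSON-B bean along these lines (a sketch; the generated class also carries a greeting field, which is why it shows up in the OpenAPI document later):

// Simple JSON-B bean generated by the Helidon starter (sketch)
public class Message {

    private String message;
    private String greeting;

    public String getMessage() {
        return this.message;
    }

    public void setMessage(String message) {
        this.message = message;
    }

    public String getGreeting() {
        return this.greeting;
    }

    public void setGreeting(String greeting) {
        this.greeting = greeting;
    }
}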
Let’s take one step back and check the Helidon prerequisites before building the application.
As for GraalVM, make sure the Java version is 17 before downloading it:
If you’re using Oracle Linux in OCI, GraalVM can be installed from the yum repositories:
# yum install graalvm22-ee-17-jdk.x86_64
And then point JAVA_HOME and the PATH to the installation folder (paths for the GraalVM RPM-packaged version are shown below):
❯ export JAVA_HOME=/usr/lib64/graalvm/graalvm22-ee-java17
❯ export PATH=$JAVA_HOME/bin:$PATH
❯ which java
/usr/lib64/graalvm/graalvm22-ee-java17/bin/java
❯ java -version
java version "17.0.7" 2023-04-18 LTS
Java(TM) SE Runtime Environment GraalVM EE 22.3.2 (build 17.0.7+8-LTS-jvmci-22.3-b15)
Java HotSpot(TM) 64-Bit Server VM GraalVM EE 22.3.2 (build 17.0.7+8-LTS-jvmci-22.3-b15, mixed mode, sharing)
Let’s compile it and build a jar file.
❯ mvn package
[...]
[INFO] Building jar: /home/user/helidon-onnx/target/myproject.jar
[INFO] ------------------------------------------------------------------------
[INFO] BUILD SUCCESS
[INFO] ------------------------------------------------------------------------
[INFO] Total time: 24.335 s
[INFO] Finished at: 2023-xx-xxTxx:xx:xxZ
[INFO] ------------------------------------------------------------------------
Time for a first execution and a couple of requests (make sure port 8080 is free before starting it):
❯ java -jar target/myproject.jar
2023.05.29 11:02:48 INFO io.helidon.common.LogConfig Thread[main,5,main]: Logging at initialization configured using classpath: /logging.properties
2023.05.29 11:02:51 INFO io.helidon.microprofile.server.ServerCdiExtension Thread[main,5,main]: Registering JAX-RS Application: HelidonMP
SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
2023.05.29 11:02:52 INFO io.helidon.webserver.NettyWebServer Thread[nioEventLoopGroup-2-1,10,main]: Channel '@default' started: [id: 0x77f857f3, L:/[0:0:0:0:0:0:0:0]:8080]
2023.05.29 11:02:52 INFO io.helidon.microprofile.server.ServerCdiExtension Thread[main,5,main]: Server started on http://localhost:8080 (and all other host addresses) in 3412 milliseconds (since JVM startup).
Let’s make some requests using curl now, just for fun:
❯ curl -s -X GET http://localhost:8080/health
{"status":"UP","checks":[]}
❯ curl 'http://localhost:8080/simple-onnx'
{"message":"Welcome to ONNX runtime"}
❯ curl 'http://localhost:8080/simple-onnx/score?distance=18.52&trip_duration=39.99&tmed=4.1&prec=0&wday=2&items=2'
{"message":"{ "Prediction": 1 }"}
By default, Helidon provides an endpoint and supporting logic for returning an OpenAPI document that describes the endpoints handled by the server, so let’s test it:
❯ curl -s http://localhost:8080/openapi
components:
  schemas:
    Message:
      properties:
        greeting:
          type: string
        message:
          type: string
      type: object
info:
  title: Generated API
  version: '1.0'
openapi: 3.0.3
paths:
  /simple-onnx:
    get:
      responses:
        '200':
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/Message'
          description: OK
  /simple-onnx/score:
    get:
      parameters:
      -
        in: query
        name: distance
        schema:
          format: float
          type: number
      -
        in: query
        name: items
        schema:
          format: float
          type: number
      -
        in: query
        name: prec
        schema:
          format: float
          type: number
      -
        in: query
        name: tmed
        schema:
          format: float
          type: number
      -
        in: query
        name: trip_duration
        schema:
          format: float
          type: number
      -
        in: query
        name: wday
        schema:
          format: float
          type: number
      responses:
        '200':
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/Message'
          description: OK
It also comes with endpoints that return metrics in JSON or Prometheus format, providing a lot of useful information to monitor the service.
# Prometheus Format
curl -s -X GET http://localhost:8080/metrics
# JSON Format
curl -H 'Accept: application/json' -X GET http://localhost:8080/metrics
Mission accomplished! The ONNX model is exposed in a REST endpoint using Helidon and GraalVM.