ONNX 是什么
ONNX(开放神经网络交换)是一种开放标准格式,用于表示机器学习模型。它支持深度学习和传统机器学习模型,允许模型在不同框架和工具之间共享和使用。
什么是 ONNX
ONNX 是一种开放标准格式,旨在表示机器学习模型。它定义了一组通用的操作符和文件格式,使 AI 开发者能够将模型与各种框架、工具、运行时和编译器一起使用。这意味着您可以在一个框架(如 PyTorch)中训练模型,然后在另一个支持 ONNX 的框架(如 TensorFlow)中部署它,而无需重新编写代码。
支持的模型类型
令人惊讶的是,ONNX 不仅支持深度学习模型,还支持传统机器学习模型,如 scikit-learn 和 XGBoost。这扩展了其在不同机器学习场景中的适用性。
实际好处
ONNX 促进了框架之间的互操作性,使开发者更容易访问硬件优化,从而提高性能。它还支持云、边缘、网页和移动设备等各种平台.
使用 ONNX Runtime Java 库
要在 Java 中使用 ONNX 模型,您可以使用 Microsoft 的 ONNX Runtime Java 库。以下是如何在项目中添加和使用该库的步骤:
添加依赖项:在您的项目中添加 ONNX Runtime 的 Maven 依赖项。
1 2 3 4 5
| <dependency> <groupId>com.microsoft.onnxruntime</groupId> <artifactId>onnxruntime</artifactId> <version>1.20.0</version> </dependency>
|
加载模型:使用 ONNX Runtime API 加载您的 ONNX 模型。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
| import ai.onnxruntime.*; import java.util.*; import java.nio.FloatBuffer; import java.nio.LongBuffer;
public class OnnxExample { public static void main(String[] args) { try (OrtEnvironment env = OrtEnvironment.getEnvironment(); OrtSession.SessionOptions opts = new OrtSession.SessionOptions()) { opts.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT); OrtSession session = env.createSession("path/to/your/model.onnx", opts); System.out.println("Model inputs: " + session.getInputNames()); System.out.println("Model outputs: " + session.getOutputNames()); } catch (OrtException e) { e.printStackTrace(); } } }
|
准备输入数据:根据模型的输入要求准备数据。
1 2 3 4 5 6 7 8
| float[] inputData = {1.0f, 2.0f, 3.0f, 4.0f}; long[] shape = {1, 4}; OnnxTensor tensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputData), shape);
Map<String, OnnxTensor> inputs = new HashMap<>(); inputs.put(session.getInputNames().iterator().next(), tensor);
|
运行推理:使用准备好的输入数据运行模型。
1 2 3 4 5 6 7 8 9 10 11
| try (OrtSession.Result results = session.run(inputs)) { OnnxTensor output = (OnnxTensor) results.get(0); float[] outputData = (float[]) output.getValue(); for (float value : outputData) { System.out.println("Output value: " + value); } }
|
资源清理:确保正确释放资源。
1 2 3 4
| tensor.close(); session.close(); env.close();
|
注意事项:
- 输入张量的形状(shape)需要根据模型的具体要求来设置
- 输入数据类型(float, int 等)需要与模型期望的类型匹配
- 使用 try-with-resources 语句确保资源正确释放
- 可以通过 SessionOptions 配置运行时选项,如优化级别、执行设备等
在Spring Boot中集成ONNX
Spring Boot是Java生态系统中流行的应用框架,将ONNX与Spring Boot集成可以创建强大的机器学习微服务。以下是如何在Spring Boot应用程序中实现ONNX模型推理的步骤。
项目设置
首先,创建一个Spring Boot项目并添加必要的依赖项:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
| <dependencies> <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-web</artifactId> </dependency> <dependency> <groupId>com.microsoft.onnxruntime</groupId> <artifactId>onnxruntime</artifactId> <version>1.20.0</version> </dependency> <dependency> <groupId>org.projectlombok</groupId> <artifactId>lombok</artifactId> <optional>true</optional> </dependency> </dependencies>
|
创建ONNX服务
创建一个服务类来处理ONNX模型的加载和推理:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
| package com.example.onnxdemo.service;
import ai.onnxruntime.*; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service;
import javax.annotation.PostConstruct; import javax.annotation.PreDestroy; import java.nio.FloatBuffer; import java.util.HashMap; import java.util.Map;
@Service public class OnnxService {
@Value("${onnx.model.path}") private String modelPath; private OrtEnvironment environment; private OrtSession session; @PostConstruct public void initialize() throws OrtException { environment = OrtEnvironment.getEnvironment(); OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions(); sessionOptions.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT); session = environment.createSession(modelPath, sessionOptions); System.out.println("ONNX模型已加载,输入: " + session.getInputNames()); System.out.println("ONNX模型已加载,输出: " + session.getOutputNames()); } public float[] runInference(float[] inputData, long[] shape) throws OrtException { OnnxTensor tensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(inputData), shape); Map<String, OnnxTensor> inputs = new HashMap<>(); inputs.put(session.getInputNames().iterator().next(), tensor); try (OrtSession.Result results = session.run(inputs)) { OnnxTensor output = (OnnxTensor) results.get(0); return (float[]) output.getValue(); } finally { tensor.close(); } } @PreDestroy public void cleanup() throws OrtException { if (session != null) { session.close(); } if (environment != null) { environment.close(); } } }
|
创建REST API端点
创建一个控制器来暴露模型推理功能:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
| package com.example.onnxdemo.controller;
import com.example.onnxdemo.dto.InferenceRequest; import com.example.onnxdemo.dto.InferenceResponse; import com.example.onnxdemo.service.OnnxService; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController;
@RestController @RequestMapping("/api/inference") public class InferenceController {
@Autowired private OnnxService onnxService; @PostMapping public InferenceResponse runInference(@RequestBody InferenceRequest request) { try { float[] result = onnxService.runInference( request.getInputData(), request.getShape() ); return new InferenceResponse(result, null); } catch (Exception e) { return new InferenceResponse(null, e.getMessage()); } } }
|
创建请求和响应DTO
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| package com.example.onnxdemo.dto;
import lombok.Data;
@Data public class InferenceRequest { private float[] inputData; private long[] shape; }
@Data public class InferenceResponse { private float[] result; private String error; public InferenceResponse(float[] result, String error) { this.result = result; this.error = error; } }
|
配置应用程序属性
在application.properties或application.yml中配置模型路径:
1 2 3 4 5
| onnx.model.path=classpath:models/model.onnx
server.port=8080
|
实际应用场景
以下是在Spring Boot应用中使用ONNX的几个实际应用场景:
图像分类服务:使用预训练的图像分类模型(如ResNet或MobileNet)创建REST API,接收图像并返回分类结果。
自然语言处理:部署BERT或GPT等模型的ONNX版本,提供文本分类、情感分析或问答功能。
推荐系统:使用ONNX模型实现产品推荐功能,可以集成到电子商务平台。
异常检测:在金融或安全领域,使用ONNX模型检测异常交易或行为。
性能优化技巧
在Spring Boot应用中使用ONNX时,可以考虑以下性能优化技巧:
模型缓存:将模型加载到内存中并重用,避免重复加载。
批处理:尽可能批量处理请求,减少模型推理的次数。
异步处理:对于非实时需求,使用Spring的异步功能处理推理请求。
资源管理:正确管理ONNX资源,确保在不需要时释放。
监控和指标:使用Spring Boot Actuator监控模型性能和资源使用情况。
1 2 3 4 5 6 7 8 9 10
| public List<float[]> batchInference(List<float[]> inputBatch, long[] shape) throws OrtException { List<float[]> results = new ArrayList<>(); for (float[] input : inputBatch) { results.add(runInference(input, shape)); } return results; }
|
结论
ONNX为Java开发者提供了一种强大的方式来集成和部署机器学习模型,特别是在Spring Boot等企业级框架中。通过使用ONNX Runtime Java API,您可以轻松地将各种机器学习模型集成到您的Java应用程序中,无论这些模型最初是在哪个框架中训练的。
随着机器学习在企业应用中的普及,ONNX的互操作性优势将变得越来越重要,使开发者能够专注于创建价值,而不是解决技术兼容性问题。