Rivet 3.1.9
RivetONNXrt.hh
1// -*- C++ -*-
2#ifndef RIVET_RivetONNXrt_HH
3#define RIVET_RivetONNXrt_HH
4
5#include <iostream>
6#include <functional>
7#include <numeric>
8
9#include "Rivet/Tools/RivetPaths.hh"
10#include "Rivet/Tools/Utils.hh"
11#include "onnxruntime/onnxruntime_cxx_api.h"
12
13
14namespace Rivet {
15
21
22 public:
23
    // Suppress the default constructor: an ONNX model file is always required
    RivetONNXrt() = delete;

28 RivetONNXrt(const string& filename, const string& runname = "RivetONNXrt") {
29
30 // Set some ORT variables that need to be kept in memory
31 _env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, runname.c_str());
32
33 // Load the model
34 Ort::SessionOptions sessionopts;
35 _session = std::make_unique<Ort::Session> (*_env, filename.c_str(), sessionopts);
36
37 // Store network hyperparameters (input/output shape, etc.)
38 getNetworkInfo();
39
40 MSG_DEBUG(*this);
41 }
42
44 vector<vector<float>> compute(vector<vector<float>>& inputs) const {
45
47 if (inputs.size() != _inDims.size()) {
48 throw("Expected " + to_string(_inDims.size())
49 + " input nodes, received " + to_string(inputs.size()));
50 }
51
52 // Create input tensor objects from input data
53 vector<Ort::Value> ort_input;
54 ort_input.reserve(_inDims.size());
55 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
56 for (size_t i=0; i < _inDims.size(); ++i) {
57
58 // Check that input data matches expected input node dimension
59 if (inputs[i].size() != _inDimsFlat[i]) {
60 throw("Expected flattened input node dimension " + to_string(_inDimsFlat[i])
61 + ", received " + to_string(inputs[i].size()));
62 }
63
64 ort_input.emplace_back(Ort::Value::CreateTensor<float>(memory_info,
65 inputs[i].data(), inputs[i].size(),
66 _inDims[i].data(), _inDims[i].size()));
67 }
68
69 // retrieve output tensors
70 auto ort_output = _session->Run(Ort::RunOptions{nullptr}, _inNames.data(),
71 ort_input.data(), ort_input.size(),
72 _outNames.data(), _outNames.size());
73
74 // construct flattened values and return
75 vector<vector<float>> outputs; outputs.resize(_outDims.size());
76 for (size_t i = 0; i < _outDims.size(); ++i) {
77 float* floatarr = ort_output[i].GetTensorMutableData<float>();
78 outputs[i].assign(floatarr, floatarr + _outDimsFlat[i]);
79 }
80 return outputs;
81 }
82
84 vector<float> compute(const vector<float>& inputs) const {
85 if (_inDims.size() != 1 || _outDims.size() != 1) {
86 throw("This method assumes a single input/output node!");
87 }
88 vector<vector<float>> wrapped_inputs = { inputs };
89 vector<vector<float>> outputs = compute(wrapped_inputs);
90 return outputs[0];
91 }
92
94 const bool hasKey(const std::string& key) const {
95 Ort::AllocatorWithDefaultOptions allocator;
96 return (bool)_metadata->LookupCustomMetadataMapAllocated(key.c_str(), allocator);
97 }
98
101 template <typename T>
102 const T retrieve(const std::string& key) const {
103 Ort::AllocatorWithDefaultOptions allocator;
104 Ort::AllocatedStringPtr res = _metadata->LookupCustomMetadataMapAllocated(key.c_str(), allocator);
105 if (!res) {
106 throw("Ket '"+key+"' not found in network metadata!");
107 }
108 /*if constexpr (std::is_same<T, std::string>::value) {
109 return res.get();
110 }*/
111 return lexical_cast<T>(res.get());
112 }
113
115 const std::string retrieve(const std::string& key) const {
116 Ort::AllocatorWithDefaultOptions allocator;
117 Ort::AllocatedStringPtr res = _metadata->LookupCustomMetadataMapAllocated(key.c_str(), allocator);
118 if (!res) {
119 throw("Ket '"+key+"' not found in network metadata!");
120 }
121 return res.get();
122 }
123
124 const std::string retrieve(const std::string& key, const std::string& defaultreturn) const {
125 try {
126 return retrieve(key);
127 } catch (...) {
128 return defaultreturn;
129 }
130 }
131
134 template <typename T>
135 const T retrieve(const std::string& key, const T& defaultreturn) const {
136 try {
137 return retrieve<T>(key);
138 } catch (...) {
139 return defaultreturn;
140 }
141 }
142
144 friend std::ostream& operator <<(std::ostream& os, const RivetONNXrt& rort){
145 os << "RivetONNXrt Network Summary: \n";
146 for (size_t i=0; i < rort._inNames.size(); ++i) {
147 os << "- Input node " << i << " name: " << rort._inNames[i];
148 os << ", dimensions: (";
149 for (size_t j=0; j < rort._inDims[i].size(); ++j){
150 if (j) os << ", ";
151 os << rort._inDims[i][j];
152 }
153 os << "), type (as ONNX enums): " << rort._inTypes[i] << "\n";
154 }
155 for (size_t i=0; i < rort._outNames.size(); ++i) {
156 os << "- Output node " << i << " name: " << rort._outNames[i];
157 os << ", dimensions: (";
158 for (size_t j=0; j < rort._outDims[i].size(); ++j){
159 if (j) os << ", ";
160 os << rort._outDims[i][j];
161 }
162 os << "), type (as ONNX enums): (" << rort._outTypes[i] << "\n";
163 }
164 return os;
165 }
166
168 Log& getLog() const {
169 string logname = "Rivet.RivetONNXrt";
170 return Log::getLog(logname);
171 }
172
173
174 private:
175
176 void getNetworkInfo() {
177
178 Ort::AllocatorWithDefaultOptions allocator;
179
180 // Retrieve network metadat
181 _metadata = std::make_unique<Ort::ModelMetadata>(_session->GetModelMetadata());
182
183 // find out how many input nodes the model expects
184 const size_t num_input_nodes = _session->GetInputCount();
185 _inDimsFlat.reserve(num_input_nodes);
186 _inTypes.reserve(num_input_nodes);
187 _inDims.reserve(num_input_nodes);
188 _inNames.reserve(num_input_nodes);
189 _inNamesPtr.reserve(num_input_nodes);
190 for (size_t i = 0; i < num_input_nodes; ++i) {
191 // retrieve input node name
192 auto input_name = _session->GetInputNameAllocated(i, allocator);
193 _inNames.push_back(input_name.get());
194 _inNamesPtr.push_back(std::move(input_name));
195
196 // retrieve input node type
197 auto in_type_info = _session->GetInputTypeInfo(i);
198 auto in_tensor_info = in_type_info.GetTensorTypeAndShapeInfo();
199 _inTypes.push_back(in_tensor_info.GetElementType());
200 _inDims.push_back(in_tensor_info.GetShape());
201 }
202
203 // Fix negative shape values - appears to be an artefact of batch size issues.
204 for (auto& dims : _inDims) {
205 int64_t n = 1;
206 for (auto& dim : dims) {
207 if (dim < 0) dim = abs(dim);
208 n *= dim;
209 }
210 _inDimsFlat.push_back(n);
211 }
212
213 // find out how many output nodes the model expects
214 const size_t num_output_nodes = _session->GetOutputCount();
215 _outDimsFlat.reserve(num_output_nodes);
216 _outTypes.reserve(num_output_nodes);
217 _outDims.reserve(num_output_nodes);
218 _outNames.reserve(num_output_nodes);
219 _outNamesPtr.reserve(num_output_nodes);
220 for (size_t i = 0; i < num_output_nodes; ++i) {
221 // retrieve output node name
222 auto output_name = _session->GetOutputNameAllocated(i, allocator);
223 _outNames.push_back(output_name.get());
224 _outNamesPtr.push_back(std::move(output_name));
225
226 // retrieve input node type
227 auto out_type_info = _session->GetOutputTypeInfo(i);
228 auto out_tensor_info = out_type_info.GetTensorTypeAndShapeInfo();
229 _outTypes.push_back(out_tensor_info.GetElementType());
230 _outDims.push_back(out_tensor_info.GetShape());
231 }
232
233 // Fix negative shape values - appears to be an artefact of batch size issues.
234 for (auto& dims : _outDims) {
235 int64_t n = 1;
236 for (auto& dim : dims) {
237 if (dim < 0) dim = abs(dim);
238 n *= dim;
239 }
240 _outDimsFlat.push_back(n);
241 }
242 }
243
  private:

    /// ONNX runtime environment; must be kept alive for the session's lifetime
    std::unique_ptr<Ort::Env> _env;

    /// ONNX inference session wrapping the loaded model
    std::unique_ptr<Ort::Session> _session;

    /// Model metadata (custom key/value map) queried by hasKey()/retrieve()
    std::unique_ptr<Ort::ModelMetadata> _metadata;

    /// Expected shapes of the input and output nodes
    vector<vector<int64_t>> _inDims, _outDims;

    /// Flattened (product-of-dims) size of each input/output node
    vector<int64_t> _inDimsFlat, _outDimsFlat;

    /// Element types of the input/output nodes (as ONNX enums)
    vector<ONNXTensorElementDataType> _inTypes, _outTypes;

    /// Owning pointers that keep the node-name strings alive
    vector<Ort::AllocatedStringPtr> _inNamesPtr, _outNamesPtr;

    /// Raw node-name pointers (views into _inNamesPtr/_outNamesPtr) passed to Session::Run
    vector<const char*> _inNames, _outNames;
271 };
272}
273
274
275#endif
Logging system for controlled & formatted writing to stdout.
Definition Logging.hh:10
static Log & getLog(const std::string &name)
Simple interface class to take care of basic ONNX networks.
Definition RivetONNXrt.hh:20
Log & getLog() const
Logger.
Definition RivetONNXrt.hh:168
const bool hasKey(const std::string &key) const
Method to check if a key exists in the network metadata.
Definition RivetONNXrt.hh:94
const T retrieve(const std::string &key) const
Definition RivetONNXrt.hh:102
const T retrieve(const std::string &key, const T &defaultreturn) const
Definition RivetONNXrt.hh:135
friend std::ostream & operator<<(std::ostream &os, const RivetONNXrt &rort)
Printing function for debugging.
Definition RivetONNXrt.hh:144
const std::string retrieve(const std::string &key) const
Template specialisation of retrieve for std::string.
Definition RivetONNXrt.hh:115
vector< float > compute(const vector< float > &inputs) const
Given a single-node input vector, populate and return the single-node output vector.
Definition RivetONNXrt.hh:84
RivetONNXrt(const string &filename, const string &runname="RivetONNXrt")
Constructor.
Definition RivetONNXrt.hh:28
vector< vector< float > > compute(vector< vector< float > > &inputs) const
Given a multi-node input vector, populate and return the multi-node output vector.
Definition RivetONNXrt.hh:44
#define MSG_DEBUG(x)
Debug messaging, not enabled by default, using MSG_LVL.
Definition Logging.hh:195
Definition MC_Cent_pPb.hh:10