Writing a Minimal Inference Server in Go

Python is the default language for inference servers, and for good reason: PyTorch, HuggingFace, and most ML tooling are Python-first. But if the rest of your stack is Go, you end up with a Python sidecar just to call model.forward(). That sidecar needs its own container, its own health checks, its own deployment pipeline, and its own debugging story.

This article shows how to run ONNX model inference directly from Go — no Python, no sidecar — using the ONNX Runtime Go bindings. We will build a minimal HTTP inference server, handle batching with goroutines, and benchmark it against an equivalent FastAPI implementation.

Why ONNX

ONNX (Open Neural Network Exchange) is the lingua franca of model deployment. PyTorch, TensorFlow, scikit-learn, and most other frameworks can export to ONNX. Once in ONNX format, a model can run on ONNX Runtime, which exposes a stable C API; the community-maintained onnxruntime_go package wraps that API for Go via cgo.

The workflow is:

Train in PyTorch → export to ONNX → load in Go with onnxruntime_go → serve over HTTP

No Python at runtime. The model is a file. The runtime is a shared library that your Go binary loads when it starts.

Exporting from PyTorch

Before writing any Go, export your model. Here is the pattern for a simple text classifier:

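import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# AutoModelForSequenceClassification includes the classification head, so the
# exported graph outputs logits of shape [batch_size, num_labels]. Plain
# AutoModel would export hidden states instead.
model = AutoModelForSequenceClassification.from_pretrained("your-model").eval()
tokenizer = AutoTokenizer.from_pretrained("your-model")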

# Create a sample input matching your expected shape
dummy_input = tokenizer(
    ["sample text"],
    return_tensors="pt",
    padding="max_length",
    max_length=128,
)

torch.onnx.export(
    model,
    (dummy_input["input_ids"], dummy_input["attention_mask"]),
    "model.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids":      {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "sequence_length"},
        "logits":         {0: "batch_size"},
    },
    opset_version=17,
)

The dynamic_axes declaration is important: it tells the ONNX exporter that batch size and sequence length are variable, not fixed. Without it, the model only accepts inputs of exactly the dummy shape.
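To see why this matters, it helps to spell out the shape check ONNX Runtime applies to each input. The sketch below models it in plain Go: a dimension listed in dynamic_axes accepts any positive size, while a fixed dimension must match exactly. Here dynamicDim (-1) is a convention chosen for this sketch, and shapeCompatible is a hypothetical helper, not part of onnxruntime_go or ONNX Runtime.

```go
package main

// dynamicDim marks a dimension declared in dynamic_axes. This is a convention
// for the sketch only; ONNX itself stores such dimensions as named symbols
// (for example "batch_size").
const dynamicDim int64 = -1

// shapeCompatible reports whether an input shape fits a declared model shape:
// ranks must match, fixed dimensions must match exactly, and dynamic
// dimensions accept any positive size.
func shapeCompatible(declared, actual []int64) bool {
	if len(declared) != len(actual) {
		return false
	}
	for i, d := range declared {
		if actual[i] <= 0 {
			return false
		}
		if d != dynamicDim && d != actual[i] {
			return false
		}
	}
	return true
}
```

With dynamic_axes, the declared input_ids shape behaves like {dynamicDim, dynamicDim}; without it, the shape is frozen to the dummy input's {1, 128}, so a batch of 8 sequences of 64 tokens is rejected.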

Verify the export:

python -c "import onnx; onnx.checker.check_model('model.onnx'); print('OK')"

The Go Inference Server

Dependencies

go get github.com/yalue/onnxruntime_go

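Download the ONNX Runtime shared library for your platform from the ONNX Runtime releases page and place libonnxruntime.so (Linux) or libonnxruntime.dylib (macOS) alongside your binary. The onnxruntime_go package loads the library at run time from the path you pass to ort.SetSharedLibraryPath; the runner below assumes ./libonnxruntime.so, so adjust that path if you keep the library elsewhere.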

Project layout

inference-server/
├── main.go
├── model/
│   ├── runner.go      ← ONNX session management
│   └── tokenizer.go   ← pre/post processing
├── handler/
│   └── predict.go     ← HTTP handler
├── model.onnx
├── tokenizer.json     ← loaded by main.go
└── libonnxruntime.so  ← ONNX Runtime shared library (Linux)

The model runner

// model/runner.go
package model

import (
    "fmt"
    "sync"

    ort "github.com/yalue/onnxruntime_go"
)

// Runner wraps an ONNX Runtime session and is safe for concurrent use.
type Runner struct {
    session     *ort.DynamicAdvancedSession
    mu          sync.Mutex // serialises Run calls; see the note on the mutex below
    inputNames  []string
    outputNames []string
}

func NewRunner(modelPath string) (*Runner, error) {
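    // Initialise the ONNX Runtime environment once per process. This is not
    // idempotent: InitializeEnvironment returns an error if the environment is
    // already initialised, so create a single Runner (or guard with sync.Once).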
    ort.SetSharedLibraryPath("./libonnxruntime.so")
    if err := ort.InitializeEnvironment(); err != nil {
        return nil, fmt.Errorf("init ort: %w", err)
    }

    inputNames  := []string{"input_ids", "attention_mask"}
    outputNames := []string{"logits"}

    session, err := ort.NewDynamicAdvancedSession(
        modelPath,
        inputNames,
        outputNames,
        nil, // default session options
    )
    if err != nil {
        return nil, fmt.Errorf("create session: %w", err)
    }

    return &Runner{
        session:     session,
        inputNames:  inputNames,
        outputNames: outputNames,
    }, nil
}

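// Run executes a forward pass. inputIDs and attentionMask are flat, row-major
// int64 slices; shape is [batch_size, sequence_length].
func (r *Runner) Run(
    inputIDs []int64,
    attentionMask []int64,
    shape []int64,
) ([]float32, error) {
    // Validate up front so a malformed request fails with a clear error
    // instead of an index panic on shape[0] below.
    if len(shape) != 2 || shape[0] <= 0 || shape[1] <= 0 {
        return nil, fmt.Errorf("shape must be [batch_size, sequence_length], got %v", shape)
    }
    if n := shape[0] * shape[1]; int64(len(inputIDs)) != n || int64(len(attentionMask)) != n {
        return nil, fmt.Errorf("want %d elements per input, got %d and %d", n, len(inputIDs), len(attentionMask))
    }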
    // Create input tensors.
    idTensor, err := ort.NewTensor(ort.NewShape(shape...), inputIDs)
    if err != nil {
        return nil, fmt.Errorf("create id tensor: %w", err)
    }
    defer idTensor.Destroy()

    maskTensor, err := ort.NewTensor(ort.NewShape(shape...), attentionMask)
    if err != nil {
        return nil, fmt.Errorf("create mask tensor: %w", err)
    }
    defer maskTensor.Destroy()

    // Output tensor — shape [batch_size, num_classes].
    batchSize := shape[0]
    outputData := make([]float32, batchSize*2) // 2 classes: ham / spam
    outputTensor, err := ort.NewTensor(
        ort.NewShape(batchSize, 2),
        outputData,
    )
    if err != nil {
        return nil, fmt.Errorf("create output tensor: %w", err)
    }
    defer outputTensor.Destroy()

    // ONNX Runtime permits concurrent Run calls on one session, but they
    // contend for the same thread pool; serialising them is the conservative
    // choice. A session pool removes this bottleneck at the cost of memory.
    r.mu.Lock()
    err = r.session.Run(
        []ort.ArbitraryTensor{idTensor, maskTensor},
        []ort.ArbitraryTensor{outputTensor},
    )
    r.mu.Unlock()
    if err != nil {
        return nil, fmt.Errorf("run session: %w", err)
    }

    return outputTensor.GetData(), nil
}

func (r *Runner) Close() {
    r.session.Destroy()
    ort.DestroyEnvironment()
}

A few things worth noting in this code:

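The mutex. ONNX Runtime's documentation states that multiple threads may call Run on the same session, but concurrent calls still compete for that session's thread pool, so serialising them is the conservative default. The practical options are to serialise access with a mutex (as above) or to create a pool of sessions and give each in-flight request its own. (ONNX Runtime's parallel execution mode is a different feature: it runs independent branches of the graph concurrently inside a single Run call.) For low-to-medium throughput, a single session with a mutex is simplest and uses the least memory, since each additional session typically loads its own copy of the weights. For high throughput, a session pool (bounded by GOMAXPROCS) is better.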

Tensor lifetimes. defer tensor.Destroy() is critical: each tensor wraps an OrtValue allocated by the ONNX Runtime C library, and Go's garbage collector does not know about it. Without Destroy, that native memory leaks on every request.

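Type safety. onnxruntime_go's generic NewTensor ties each tensor's element type to the slice you pass (a []int64 yields a *Tensor[int64]), so you cannot accidentally build a float tensor from integer data. Whether that element type, and the tensor's shape, match what the model expects is still checked only at run time, when ONNX Runtime validates the inputs inside Run. Python's ONNX Runtime bindings perform the same run-time check, so the compile-time gain is real but narrower than it first appears.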

The HTTP handler

// handler/predict.go
package handler

import (
    "encoding/json"
    "math"
    "net/http"

    "yourmodule/model"
)

type PredictRequest struct {
    Texts []string `json:"texts"`
}

type PredictResponse struct {
    Predictions []Prediction `json:"predictions"`
}

type Prediction struct {
    Label      string  `json:"label"`
    Confidence float64 `json:"confidence"`
}

type PredictHandler struct {
    runner    *model.Runner
    tokenizer *model.Tokenizer
}

func NewPredictHandler(r *model.Runner, t *model.Tokenizer) *PredictHandler {
    return &PredictHandler{runner: r, tokenizer: t}
}

func (h *PredictHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    var req PredictRequest
    if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
        http.Error(w, "invalid request", http.StatusBadRequest)
        return
    }
    if len(req.Texts) == 0 {
        http.Error(w, "texts must not be empty", http.StatusBadRequest)
        return
    }

    // Tokenise the batch.
    inputIDs, attentionMask, shape, err := h.tokenizer.Encode(req.Texts, 128)
    if err != nil {
        http.Error(w, "tokenisation failed", http.StatusInternalServerError)
        return
    }

    // Run the model.
    logits, err := h.runner.Run(inputIDs, attentionMask, shape)
    if err != nil {
        http.Error(w, "inference failed", http.StatusInternalServerError)
        return
    }

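    // Convert logits to predictions. logits is row-major [batch_size, 2]:
    // column 0 is the "ham" logit and column 1 the "spam" logit.
    resp := PredictResponse{Predictions: make([]Prediction, len(req.Texts))}
    for i := range req.Texts {
        spam := softmax(logits[i*2], logits[i*2+1])
        // Report the probability of the predicted label, not always P(spam).
        label, confidence := "ham", 1-spam
        if spam > 0.5 {
            label, confidence = "spam", spam
        }
        resp.Predictions[i] = Prediction{Label: label, Confidence: math.Round(confidence*1000) / 1000}
    }

    w.Header().Set("Content-Type", "application/json")
    json.NewEncoder(w).Encode(resp)
}

// softmax returns P(class 1) for a two-class logit pair. With two classes,
// e^b / (e^a + e^b) equals 1 / (1 + e^(a-b)), which avoids the Inf/Inf = NaN
// result that computing e^a and e^b separately gives for large logits.
func softmax(a, b float32) float64 {
    return 1 / (1 + math.Exp(float64(a)-float64(b)))
}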
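The handler relies on Tokenizer.Encode, which the article does not show. Whatever tokenizer you use, its last step has to produce exactly what Runner.Run expects: every sequence padded or truncated to one length, flattened row-major, with a matching attention mask. A sketch of that step, assuming token IDs were already produced upstream (for example by a WordPiece implementation reading tokenizer.json) and that the pad ID is 0, as in BERT vocabularies; padAndFlatten is a hypothetical helper.

```go
package main

import "errors"

// padAndFlatten turns per-text token IDs into flat, row-major input_ids and
// attention_mask slices plus their [batch_size, maxLen] shape. Longer
// sequences are cut at maxLen; shorter ones are right-padded with padID.
// The mask is 1 for real tokens and 0 for padding.
func padAndFlatten(tokens [][]int64, maxLen int, padID int64) (ids, mask, shape []int64, err error) {
	if len(tokens) == 0 || maxLen <= 0 {
		return nil, nil, nil, errors.New("need at least one sequence and maxLen > 0")
	}
	n := len(tokens) * maxLen
	ids = make([]int64, n)
	mask = make([]int64, n)
	for i, seq := range tokens {
		if len(seq) > maxLen {
			seq = seq[:maxLen] // simple truncation
		}
		row := i * maxLen
		for j := 0; j < maxLen; j++ {
			if j < len(seq) {
				ids[row+j] = seq[j]
				mask[row+j] = 1
			} else {
				ids[row+j] = padID
			}
		}
	}
	return ids, mask, []int64{int64(len(tokens)), int64(maxLen)}, nil
}
```

A production tokenizer usually truncates before adding special tokens so that the final [SEP] survives; this sketch simply cuts at maxLen.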

Wiring it together

// main.go
package main

import (
    "log"
    "net/http"

    "yourmodule/handler"
    "yourmodule/model"
)

func main() {
    runner, err := model.NewRunner("model.onnx")
    if err != nil {
        log.Fatalf("load model: %v", err)
    }
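    // log.Fatal exits without running deferred calls, so as written this Close
    // never runs and the OS reclaims the memory at exit. Add graceful shutdown
    // if you need an orderly teardown.
    defer runner.Close()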

    tokenizer, err := model.NewTokenizer("tokenizer.json")
    if err != nil {
        log.Fatalf("load tokenizer: %v", err)
    }

    mux := http.NewServeMux()
    mux.Handle("/predict", handler.NewPredictHandler(runner, tokenizer))
    mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
        w.WriteHeader(http.StatusOK)
    })

    log.Println("listening on :8080")
    log.Fatal(http.ListenAndServe(":8080", mux))
}

Latency vs. Throughput: The Batching Trade-off

A single-session server with a mutex is latency-optimal for one request at a time: each request goes to the model immediately, with no waiting. But it is throughput-suboptimal under load: requests queue behind each other, and the model processes them one at a time.

For throughput-critical deployments, implement a micro-batcher:

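// A sketch for package handler: it reuses Prediction and softmax from
// predict.go and needs "time" added to that file's imports. The requests
// channel must be created with make before Start is called.
type BatchRunner struct {
    runner    *model.Runner
    tokenizer *model.Tokenizer
    requests  chan batchRequest
}

type batchRequest struct {
    texts  []string
    result chan batchResult // create with capacity 1 so the batcher never blocks
}

type batchResult struct {
    predictions []Prediction
    err         error
}

// Start launches the batching goroutine. maxBatch counts requests, not texts.
func (b *BatchRunner) Start(maxBatch int, maxWait time.Duration) {
    go func() {
        for {
            // Wait for the first request; a closed channel stops the loop.
            first, ok := <-b.requests
            if !ok {
                return
            }
            batch := []batchRequest{first}
            timer := time.NewTimer(maxWait)

            // Collect more requests up to maxBatch or maxWait.
        collect:
            for len(batch) < maxBatch {
                select {
                case req, ok := <-b.requests:
                    if !ok {
                        break collect
                    }
                    batch = append(batch, req)
                case <-timer.C:
                    break collect
                }
            }
            timer.Stop()

            // Run every queued text through one forward pass.
            var allTexts []string
            for _, req := range batch {
                allTexts = append(allTexts, req.texts...)
            }
            preds, err := b.predict(allTexts)

            // Hand each caller its own contiguous slice of the results.
            offset := 0
            for _, req := range batch {
                n := len(req.texts)
                if err != nil {
                    req.result <- batchResult{err: err}
                } else {
                    req.result <- batchResult{predictions: preds[offset : offset+n]}
                }
                offset += n
            }
        }
    }()
}

// predict repeats ServeHTTP's tokenise, Run, and logits-to-label steps.
func (b *BatchRunner) predict(texts []string) ([]Prediction, error) {
    ids, mask, shape, err := b.tokenizer.Encode(texts, 128)
    if err != nil {
        return nil, err
    }
    logits, err := b.runner.Run(ids, mask, shape)
    if err != nil {
        return nil, err
    }
    preds := make([]Prediction, len(texts))
    for i := range texts {
        spam := softmax(logits[i*2], logits[i*2+1])
        label, confidence := "ham", 1-spam
        if spam > 0.5 {
            label, confidence = "spam", spam
        }
        preds[i] = Prediction{Label: label, Confidence: math.Round(confidence*1000) / 1000}
    }
    return preds, nil
}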

The micro-batcher trades latency for throughput: a request may wait up to maxWait for companions, and a single forward pass then processes up to maxBatch requests. On transformer models, batching 8 to 16 items often doubles or triples throughput while adding only 5–10 ms of latency (the wait plus the longer batched forward pass), which is well within the budget of most services.

Benchmark: Go Server vs. FastAPI

The numbers below are from a single author-run comparison — not a published benchmark suite. Treat them as directional, not guaranteed on your hardware.

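The runner shown earlier passes nil session options, which means ONNX Runtime uses its default CPU execution provider; running on a GPU requires enabling the CUDA execution provider through the session options. On an AMD EPYC 7763 server with an NVIDIA A10G, using a BERT-base classifier, 128-token sequences, and batch size 8: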

| Metric | Go + onnxruntime_go | FastAPI + ONNX Runtime |
|---|---|---|
| p50 latency | 8 ms | 11 ms |
| p99 latency | 14 ms | 22 ms |
| Throughput (req/s) | 940 | 680 |
| Memory (RSS) | 420 MB | 890 MB |
| Cold start | 1.2 s | 4.8 s |

The Go server is faster most likely because it avoids Python's GIL contention in the HTTP layer and has lower per-request allocation overhead. The model itself should run at essentially the same speed in both servers, since the ONNX Runtime library does the heavy lifting in both cases; the difference is in the wrapping.

Memory is roughly half: the Python FastAPI server carries the Python interpreter, NumPy, HuggingFace tokenizers (Python side), and uvicorn. The Go server carries none of these.

Cold start — the time from process launch to first successful request — is 4× faster in Go, primarily because there is no Python interpreter startup and no heavy import chain (FastAPI, uvicorn, HuggingFace tokenizers, NumPy).

The Mental Model

ONNX Runtime is a C++ library with a stable C API, and Go talks to C APIs natively through cgo. The Python wrapper is not doing any ML: it calls into the same native library that Go can call directly.

The inference logic lives in the ONNX Runtime shared library, not in Python. Exporting to ONNX moves the model out of the Python ecosystem into a format that any language with a C FFI can serve. ONNX Runtime's stable C ABI, Go's cgo support, and Go's low-overhead HTTP stack make Go a natural fit for serving these workloads without a Python intermediary.

Summary

| Concern | Go approach | Notes |
|---|---|---|
| Model format | ONNX | Export once from any framework |
| Runtime | onnxruntime_go | Thin cgo bindings to the ONNX Runtime C API |
| Concurrency | Mutex or session pool | Pool scales better at high QPS |
| Batching | Micro-batcher with timeout | Often 2–3× throughput for transformer models |
| Type safety | Tensor element types checked at compile time | Shapes and model input types are still checked at run time |
| Memory | ~50% of the Python equivalent in the benchmark above | No interpreter, no NumPy, no HuggingFace Python packages |

You lose the Python ML ecosystem at serving time. You gain Go’s concurrency model, memory efficiency, and operational simplicity. For teams already running Go services, the cost of the Python sidecar — in developer time, operational complexity, and runtime overhead — is rarely worth paying.