Modernize invertergui: MQTT write support, HA integration, UI updates
Some checks failed
build / inverter_gui_pipeline (push) Has been cancelled

This commit is contained in:
2026-02-19 12:03:52 +11:00
parent 959d1e3c1f
commit a31a0b4829
460 changed files with 19655 additions and 40205 deletions

2
.gitignore vendored
View File

@@ -22,3 +22,5 @@ _testmain.go
*.exe
*.test
*.prof
vendor/

View File

@@ -1,22 +0,0 @@
run:
deadline: 10m
linters:
enable-all: false
enable:
- errcheck
- govet
- ineffassign
- typecheck
- dupl
- goconst
- gofmt
- unconvert
- revive
- staticcheck
- gosimple
- unused
issues:
max-per-linter: 0
max-same-issues: 0

108
README.md
View File

@@ -21,7 +21,7 @@ docker run --name invertergui --device /dev/ttyUSB0:/dev/ttyUSB0 -p 8080:8080 gh
## Requirements
This project makes use of [Go Modules](https://github.com/golang/go/wiki/Modules). The minimum version for Go is 1.16
This project makes use of [Go Modules](https://github.com/golang/go/wiki/Modules). The minimum supported version for Go is 1.22
## Getting started
@@ -37,10 +37,17 @@ Application Options:
--cli.enabled Enable CLI output. [$CLI_ENABLED]
--mqtt.enabled Enable MQTT publishing. [$MQTT_ENABLED]
--mqtt.broker= Set the host port and scheme of the MQTT broker. (default: tcp://localhost:1883) [$MQTT_BROKER]
--mqtt.client_id= Set the client ID for the MQTT connection. (default: interter-gui) [$MQTT_CLIENT_ID]
--mqtt.client_id= Set the client ID for the MQTT connection. (default: inverter-gui) [$MQTT_CLIENT_ID]
--mqtt.topic= Set the MQTT topic updates published to. (default: invertergui/updates) [$MQTT_TOPIC]
--mqtt.command_topic= Set the MQTT topic that receives write commands for Victron settings/RAM variables. (default: invertergui/settings/set) [$MQTT_COMMAND_TOPIC]
--mqtt.status_topic= Set the MQTT topic where write command status updates are published. (default: invertergui/settings/status) [$MQTT_STATUS_TOPIC]
--mqtt.ha.enabled Enable Home Assistant MQTT discovery integration. [$MQTT_HA_ENABLED]
--mqtt.ha.discovery_prefix= Set Home Assistant MQTT discovery prefix. (default: homeassistant) [$MQTT_HA_DISCOVERY_PREFIX]
--mqtt.ha.node_id= Set Home Assistant node ID used for discovery topics and unique IDs. (default: invertergui) [$MQTT_HA_NODE_ID]
--mqtt.ha.device_name= Set Home Assistant device display name. (default: Victron Inverter) [$MQTT_HA_DEVICE_NAME]
--mqtt.username= Set the MQTT username [$MQTT_USERNAME]
--mqtt.password= Set the MQTT password [$MQTT_PASSWORD]
--mqtt.password-file= Path to a file containing the MQTT password [$MQTT_PASSWORD_FILE]
--loglevel= The log level to generate logs at. ("panic", "fatal", "error", "warn", "info", "debug", "trace") (default: info) [$LOGLEVEL]
Help Options:
@@ -82,6 +89,20 @@ Battery Power: -0.659 W
Battery Charge: 100.000 %
```
The web UI also includes a **Remote Panel Control** section for:
- Remote Panel Mode (`on`, `off`, `charger_only`, `inverter_only`)
- Remote Panel Current Limit (AC input current limit in amps)
- Remote Panel Standby (prevent sleep while turned off)
The combined mode + current limit action maps to the same behavior as
`set_remote_panel_state` in `victron-mk3`.
The backing HTTP API endpoints are:
- `GET/POST /api/remote-panel/state`
- `GET/POST /api/remote-panel/standby`
### Munin
The Munin plugin location is at /munin (http://localhost:8080/munin).
@@ -281,16 +302,97 @@ The MQTT client will publish updates to the given broker at the set topic.
```bash
--mqtt.enabled Enable MQTT publishing. [$MQTT_ENABLED]
--mqtt.broker= Set the host port and scheme of the MQTT broker. (default: tcp://localhost:1883) [$MQTT_BROKER]
--mqtt.client_id= Set the client ID for the MQTT connection. (default: interter-gui) [$MQTT_CLIENT_ID]
--mqtt.client_id= Set the client ID for the MQTT connection. (default: inverter-gui) [$MQTT_CLIENT_ID]
--mqtt.topic= Set the MQTT topic updates published to. (default: invertergui/updates) [$MQTT_TOPIC]
--mqtt.command_topic= Set the MQTT topic that receives write commands for Victron settings/RAM variables. (default: invertergui/settings/set) [$MQTT_COMMAND_TOPIC]
--mqtt.status_topic= Set the MQTT topic where write command status updates are published. (default: invertergui/settings/status) [$MQTT_STATUS_TOPIC]
--mqtt.ha.enabled Enable Home Assistant MQTT discovery integration. [$MQTT_HA_ENABLED]
--mqtt.ha.discovery_prefix= Set Home Assistant MQTT discovery prefix. (default: homeassistant) [$MQTT_HA_DISCOVERY_PREFIX]
--mqtt.ha.node_id= Set Home Assistant node ID used for discovery topics and unique IDs. (default: invertergui) [$MQTT_HA_NODE_ID]
--mqtt.ha.device_name= Set Home Assistant device display name. (default: Victron Inverter) [$MQTT_HA_DEVICE_NAME]
--mqtt.username= Set the MQTT username [$MQTT_USERNAME]
--mqtt.password= Set the MQTT password [$MQTT_PASSWORD]
--mqtt.password-file= Path to a file containing the MQTT password [$MQTT_PASSWORD_FILE]
```
The MQTT client can be enabled by setting the environment variable `MQTT_ENABLED=true` or flag `--mqtt.enabled`.
All MQTT configuration can be done via flags or as environment variables.
The URI for the broker can be configured format should be `scheme://host:port`, where "scheme" is one of "tcp", "ssl", or "ws".
When `--mqtt.command_topic` is configured, the application subscribes to that topic and accepts JSON write commands.
The recommended command for inverter control follows the same model used by `victron-mk3`:
```json
{
"request_id": "optional-correlation-id",
"kind": "panel_state",
"switch": "on",
"current_limit": 16.5
}
```
`switch` supports `charger_only`, `inverter_only`, `on`, and `off` (or numeric values `1..4`).
`current_limit` is in amps and optional. If omitted, only the switch state is changed.
To update only the current limit (while preserving the last known mode), send:
```json
{
"kind": "panel_state",
"current_limit": 12.0
}
```
If no prior mode is known (for example on a fresh broker state), this command is rejected until a mode command is sent once.
Standby can be controlled with:
```json
{
"kind": "standby",
"standby": true
}
```
Low-level writes are still supported:
```json
{
"kind": "setting",
"id": 15,
"value": 1
}
```
`kind` supports `panel_state`, `setting`, and `ram_var` (with aliases for each).
The result is published to `--mqtt.status_topic` with status `ok` or `error`.
### Home Assistant
Enable Home Assistant auto-discovery with:
```bash
--mqtt.ha.enabled
```
When enabled, `invertergui` publishes retained discovery payloads and availability under:
- `{topic-root}/homeassistant/availability` (`online`/`offline`)
- `{discovery_prefix}/sensor/{node_id}/.../config`
- `{discovery_prefix}/binary_sensor/{node_id}/.../config`
- `{discovery_prefix}/select/{node_id}/remote_panel_mode/config` (if command topic is configured)
- `{discovery_prefix}/number/{node_id}/remote_panel_current_limit/config` (if command topic is configured)
- `{discovery_prefix}/switch/{node_id}/remote_panel_standby/config` (if command topic is configured)
The discovered entities include battery/input/output sensors, a data-valid diagnostic binary sensor,
plus remote panel controls for:
- `Remote Panel Mode` (`on`, `off`, `charger_only`, `inverter_only`)
- `Remote Panel Current Limit` (AC input current limit in amps)
- `Remote Panel Standby` (prevent device sleep while off)
The combined mode + current limit behavior is provided through the `panel_state` MQTT command kind,
which mirrors `victron_mk3.set_remote_panel_state`.
## TTY Device
The intertergui application makes use of a serial tty device to monitor the Multiplus.

View File

@@ -21,8 +21,16 @@ type config struct {
MQTT struct {
Enabled bool `long:"mqtt.enabled" env:"MQTT_ENABLED" description:"Enable MQTT publishing."`
Broker string `long:"mqtt.broker" env:"MQTT_BROKER" default:"tcp://localhost:1883" description:"Set the host port and scheme of the MQTT broker."`
ClientID string `long:"mqtt.client_id" env:"MQTT_CLIENT_ID" default:"interter-gui" description:"Set the client ID for the MQTT connection."`
ClientID string `long:"mqtt.client_id" env:"MQTT_CLIENT_ID" default:"inverter-gui" description:"Set the client ID for the MQTT connection."`
Topic string `long:"mqtt.topic" env:"MQTT_TOPIC" default:"invertergui/updates" description:"Set the MQTT topic updates published to."`
CommandTopic string `long:"mqtt.command_topic" env:"MQTT_COMMAND_TOPIC" default:"invertergui/settings/set" description:"Set the MQTT topic that receives write commands for Victron settings/RAM variables."`
StatusTopic string `long:"mqtt.status_topic" env:"MQTT_STATUS_TOPIC" default:"invertergui/settings/status" description:"Set the MQTT topic where write command status updates are published."`
HA struct {
Enabled bool `long:"mqtt.ha.enabled" env:"MQTT_HA_ENABLED" description:"Enable Home Assistant MQTT discovery integration."`
DiscoveryPrefix string `long:"mqtt.ha.discovery_prefix" env:"MQTT_HA_DISCOVERY_PREFIX" default:"homeassistant" description:"Set Home Assistant MQTT discovery prefix."`
NodeID string `long:"mqtt.ha.node_id" env:"MQTT_HA_NODE_ID" default:"invertergui" description:"Set Home Assistant node ID used for discovery topics and unique IDs."`
DeviceName string `long:"mqtt.ha.device_name" env:"MQTT_HA_DEVICE_NAME" default:"Victron Inverter" description:"Set Home Assistant device display name."`
}
Username string `long:"mqtt.username" env:"MQTT_USERNAME" default:"" description:"Set the MQTT username"`
Password string `long:"mqtt.password" env:"MQTT_PASSWORD" default:"" description:"Set the MQTT password"`
PasswordFile string `long:"mqtt.password-file" env:"MQTT_PASSWORD_FILE" default:"" description:"Path to a file containing the MQTT password"`

View File

@@ -37,14 +37,14 @@ import (
"net/http"
"os"
"github.com/diebietse/invertergui/mk2core"
"github.com/diebietse/invertergui/mk2driver"
"github.com/diebietse/invertergui/plugins/cli"
"github.com/diebietse/invertergui/plugins/mqttclient"
"github.com/diebietse/invertergui/plugins/munin"
"github.com/diebietse/invertergui/plugins/prometheus"
"github.com/diebietse/invertergui/plugins/webui"
"github.com/diebietse/invertergui/plugins/webui/static"
"invertergui/mk2core"
"invertergui/mk2driver"
"invertergui/plugins/cli"
"invertergui/plugins/mqttclient"
"invertergui/plugins/munin"
"invertergui/plugins/prometheus"
"invertergui/plugins/webui"
"invertergui/plugins/webui/static"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/sirupsen/logrus"
"github.com/tarm/serial"
@@ -77,9 +77,15 @@ func main() {
}
// Webgui
gui := webui.NewWebGui(core.NewSubscription())
var writer mk2driver.SettingsWriter
if w, ok := mk2.(mk2driver.SettingsWriter); ok {
writer = w
}
gui := webui.NewWebGui(core.NewSubscription(), writer)
http.Handle("/", static.New())
http.Handle("/ws", http.HandlerFunc(gui.ServeHub))
http.Handle("/api/remote-panel/state", http.HandlerFunc(gui.ServeRemotePanelState))
http.Handle("/api/remote-panel/standby", http.HandlerFunc(gui.ServeRemotePanelStandby))
// Munin
mu := munin.NewMunin(core.NewSubscription())
@@ -95,11 +101,22 @@ func main() {
mqttConf := mqttclient.Config{
Broker: conf.MQTT.Broker,
Topic: conf.MQTT.Topic,
CommandTopic: conf.MQTT.CommandTopic,
StatusTopic: conf.MQTT.StatusTopic,
ClientID: conf.MQTT.ClientID,
HomeAssistant: mqttclient.HomeAssistantConfig{
Enabled: conf.MQTT.HA.Enabled,
DiscoveryPrefix: conf.MQTT.HA.DiscoveryPrefix,
NodeID: conf.MQTT.HA.NodeID,
DeviceName: conf.MQTT.HA.DeviceName,
},
Username: conf.MQTT.Username,
Password: conf.MQTT.Password,
}
if err := mqttclient.New(core.NewSubscription(), mqttConf); err != nil {
if writer == nil {
log.Warn("MK2 data source does not support write commands; MQTT command topic will be ignored")
}
if err := mqttclient.New(core.NewSubscription(), writer, mqttConf); err != nil {
log.Fatalf("Could not setup MQTT client: %v", err)
}
}

View File

@@ -1,4 +0,0 @@
coverage:
precision: 2
round: down
range: "65...100"

34
go.mod
View File

@@ -1,29 +1,31 @@
module github.com/diebietse/invertergui
module invertergui
go 1.22
go 1.26
require (
github.com/eclipse/paho.mqtt.golang v1.4.3
github.com/gorilla/websocket v1.5.1
github.com/jessevdk/go-flags v1.5.0
github.com/prometheus/client_golang v1.19.0
github.com/sirupsen/logrus v1.9.3
github.com/stretchr/testify v1.8.2
github.com/eclipse/paho.mqtt.golang v1.5.1
github.com/gorilla/websocket v1.5.3
github.com/jessevdk/go-flags v1.6.1
github.com/prometheus/client_golang v1.23.2
github.com/sirupsen/logrus v1.9.4
github.com/stretchr/testify v1.11.1
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07
)
require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.0 // indirect
github.com/prometheus/common v0.50.0 // indirect
github.com/prometheus/procfs v0.13.0 // indirect
golang.org/x/net v0.22.0 // indirect
golang.org/x/sync v0.6.0 // indirect
golang.org/x/sys v0.18.0 // indirect
google.golang.org/protobuf v1.33.0 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.67.5 // indirect
github.com/prometheus/procfs v0.19.2 // indirect
go.yaml.in/yaml/v2 v2.4.3 // indirect
golang.org/x/net v0.50.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.41.0 // indirect
google.golang.org/protobuf v1.36.11 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

33
go.sum
View File

@@ -2,36 +2,57 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/eclipse/paho.mqtt.golang v1.4.3 h1:2kwcUGn8seMUfWndX0hGbvH8r7crgcJguQNCyp70xik=
github.com/eclipse/paho.mqtt.golang v1.4.3/go.mod h1:CSYvoAlsMkhYOXh/oKyxa8EcBci6dVkLCbo5tTC1RIE=
github.com/eclipse/paho.mqtt.golang v1.5.1 h1:/VSOv3oDLlpqR2Epjn1Q7b2bSTplJIeV2ISgCl2W7nE=
github.com/eclipse/paho.mqtt.golang v1.5.1/go.mod h1:1/yJCneuyOoCOzKSsOTUc0AJfpsItBGWvYpBLimhArU=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY=
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/jessevdk/go-flags v1.5.0 h1:1jKYvbxEjfUl0fmqTCOfonvskHHXMjBySTLW4y9LFvc=
github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4=
github.com/jessevdk/go-flags v1.6.1 h1:Cvu5U8UGrLay1rZfv/zP7iLpSHGUZ/Ou68T0iX1bBK4=
github.com/jessevdk/go-flags v1.6.1/go.mod h1:Mk8T1hIAWpOiJiHa9rJASDK2UGWji0EuPGBnNLMooyc=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.19.0 h1:ygXvpU1AoN1MhdzckN+PyD9QJOSD4x7kmXYlnfbA6JU=
github.com/prometheus/client_golang v1.19.0/go.mod h1:ZRM9uEAypZakd+q/x7+gmsvXdURP+DABIEIjnmDdp+k=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.6.0 h1:k1v3CzpSRUTrKMppY35TLwPvxHqBu0bYgxZzqGIgaos=
github.com/prometheus/client_model v0.6.0/go.mod h1:NTQHnmxFpouOD0DpvP4XujX3CdOAGQPoaGhyTchlyt8=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.50.0 h1:YSZE6aa9+luNa2da6/Tik0q0A5AbR+U003TItK57CPQ=
github.com/prometheus/common v0.50.0/go.mod h1:wHFBCEVWVmHMUpg7pYcOm2QUR/ocQdYSJVQJKnHc3xQ=
github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4=
github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw=
github.com/prometheus/procfs v0.13.0 h1:GqzLlQyfsPbaEHaQkO7tbDlriv/4o5Hudv6OXHGKX7o=
github.com/prometheus/procfs v0.13.0/go.mod h1:cd4PFCR54QLnGKPaKGA6l+cfuNXtht43ZKY6tow0Y1g=
github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws=
github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w=
github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@@ -40,18 +61,30 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07 h1:UyzmZLoiDWMRywV4DUYb9Fbt8uiOSooupjTq10vpvnU=
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA=
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc=
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=

View File

@@ -1,7 +1,7 @@
package mk2core
import (
"github.com/diebietse/invertergui/mk2driver"
"invertergui/mk2driver"
)
type Core struct {

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"io"
"math"
"sync"
"time"
@@ -49,13 +50,25 @@ const (
const (
acL1InfoFrame = 0x08
dcInfoFrame = 0x0C
interfaceFrame = 0x48 // H
setTargetFrame = 0x41
infoReqFrame = 0x46 //F
ledFrame = 0x4C
stateFrame = 0x53 // S
vFrame = 0x56
winmonFrame = 0x57
)
const (
panelStateVariant2Flags = 0x80
interfacePanelDetectFlag = 0x01
interfaceStandbyFlag = 0x02
panelCurrentLimitUnknown = 0x8000
panelCurrentLimitMax = 0x7FFF
)
// info frame types
const (
infoReqAddrDC = 0x00
@@ -65,12 +78,21 @@ const (
// winmon frame commands
const (
commandReadRAMVar = 0x30
commandWriteRAMVar = 0x32
commandWriteSetting = 0x33
commandWriteData = 0x34
commandGetRAMVarInfo = 0x36
commandUnsupportedResponse = 0x80
commandReadRAMResponse = 0x85
commandWriteRAMResponse = 0x87
commandWriteSettingResponse = 0x88
commandWriteNotAllowedResponse = 0x9B
commandGetRAMVarInfoResponse = 0x8E
)
const writeResponseTimeout = 3 * time.Second
type mk2Ser struct {
info *Mk2Info
p io.ReadWriter
@@ -79,6 +101,10 @@ type mk2Ser struct {
run chan struct{}
frameLock bool
infochan chan *Mk2Info
commandMu sync.Mutex
writeAck chan byte
stateAck chan struct{}
ifaceAck chan byte
wg sync.WaitGroup
}
@@ -89,6 +115,9 @@ func NewMk2Connection(dev io.ReadWriter) (Mk2, error) {
mk2.scaleCount = 0
mk2.frameLock = false
mk2.scales = make([]scaling, 0, ramVarMaxOffset)
mk2.writeAck = make(chan byte, 4)
mk2.stateAck = make(chan struct{}, 1)
mk2.ifaceAck = make(chan byte, 1)
mk2.setTarget()
mk2.run = make(chan struct{})
mk2.infochan = make(chan *Mk2Info)
@@ -153,6 +182,233 @@ func (m *mk2Ser) C() chan *Mk2Info {
return m.infochan
}
func (m *mk2Ser) WriteRAMVar(id uint16, value int16) error {
return m.writeByID(commandWriteRAMVar, commandWriteRAMResponse, id, value)
}
func (m *mk2Ser) WriteSetting(id uint16, value int16) error {
return m.writeByID(commandWriteSetting, commandWriteSettingResponse, id, value)
}
func (m *mk2Ser) SetPanelState(switchState PanelSwitchState, currentLimitA *float64) error {
if !validPanelSwitchState(switchState) {
return fmt.Errorf("invalid panel switch state: %d", switchState)
}
currentLimitRaw, err := encodePanelCurrentLimit(currentLimitA)
if err != nil {
return err
}
m.commandMu.Lock()
defer m.commandMu.Unlock()
m.clearStateResponses()
m.sendCommandLocked([]byte{
stateFrame,
byte(switchState),
byte(currentLimitRaw),
byte(currentLimitRaw >> 8),
0x01,
panelStateVariant2Flags,
})
return m.waitForStateResponse()
}
func (m *mk2Ser) SetStandby(enabled bool) error {
lineState := byte(interfacePanelDetectFlag)
if enabled {
lineState |= interfaceStandbyFlag
}
m.commandMu.Lock()
defer m.commandMu.Unlock()
m.clearInterfaceResponses()
m.sendCommandLocked([]byte{
interfaceFrame,
lineState,
})
return m.waitForInterfaceResponse(enabled)
}
func validPanelSwitchState(switchState PanelSwitchState) bool {
switch switchState {
case PanelSwitchChargerOnly, PanelSwitchInverterOnly, PanelSwitchOn, PanelSwitchOff:
return true
default:
return false
}
}
func encodePanelCurrentLimit(currentLimitA *float64) (uint16, error) {
if currentLimitA == nil {
return panelCurrentLimitUnknown, nil
}
if *currentLimitA < 0 {
return 0, fmt.Errorf("current_limit must be >= 0, got %.3f", *currentLimitA)
}
raw := math.Round(*currentLimitA * 10)
if raw > panelCurrentLimitMax {
return 0, fmt.Errorf("current_limit %.3f A is above MK2 maximum %.1f A", *currentLimitA, panelCurrentLimitMax/10.0)
}
return uint16(raw), nil
}
func (m *mk2Ser) writeByID(selectCommand, expectedResponse byte, id uint16, value int16) error {
m.commandMu.Lock()
defer m.commandMu.Unlock()
m.clearWriteResponses()
m.sendCommandLocked([]byte{
winmonFrame,
selectCommand,
byte(id),
byte(id >> 8),
})
rawValue := uint16(value)
m.sendCommandLocked([]byte{
winmonFrame,
commandWriteData,
byte(rawValue),
byte(rawValue >> 8),
})
return m.waitForWriteResponse(expectedResponse)
}
func (m *mk2Ser) clearWriteResponses() {
if m.writeAck == nil {
m.writeAck = make(chan byte, 4)
return
}
for {
select {
case <-m.writeAck:
default:
return
}
}
}
func (m *mk2Ser) waitForWriteResponse(expectedResponse byte) error {
if m.writeAck == nil {
return errors.New("write response channel is not initialized")
}
select {
case response := <-m.writeAck:
switch response {
case expectedResponse:
return nil
case commandUnsupportedResponse:
return errors.New("write command is not supported by this device firmware")
case commandWriteNotAllowedResponse:
return errors.New("write command rejected by device access level")
default:
return fmt.Errorf("unexpected write response 0x%02x", response)
}
case <-time.After(writeResponseTimeout):
return fmt.Errorf("timed out waiting for write response after %s", writeResponseTimeout)
}
}
func (m *mk2Ser) pushWriteResponse(response byte) {
if m.writeAck == nil {
return
}
select {
case m.writeAck <- response:
default:
}
}
func (m *mk2Ser) clearStateResponses() {
if m.stateAck == nil {
m.stateAck = make(chan struct{}, 1)
return
}
for {
select {
case <-m.stateAck:
default:
return
}
}
}
func (m *mk2Ser) waitForStateResponse() error {
if m.stateAck == nil {
return errors.New("panel state response channel is not initialized")
}
select {
case <-m.stateAck:
return nil
case <-time.After(writeResponseTimeout):
return fmt.Errorf("timed out waiting for panel state response after %s", writeResponseTimeout)
}
}
func (m *mk2Ser) pushStateResponse() {
if m.stateAck == nil {
return
}
select {
case m.stateAck <- struct{}{}:
default:
}
}
func (m *mk2Ser) clearInterfaceResponses() {
if m.ifaceAck == nil {
m.ifaceAck = make(chan byte, 1)
return
}
for {
select {
case <-m.ifaceAck:
default:
return
}
}
}
func (m *mk2Ser) waitForInterfaceResponse(expectedStandby bool) error {
if m.ifaceAck == nil {
return errors.New("interface response channel is not initialized")
}
select {
case lineState := <-m.ifaceAck:
standbyEnabled := lineState&interfaceStandbyFlag != 0
if standbyEnabled != expectedStandby {
return fmt.Errorf("unexpected standby line state 0x%02x", lineState)
}
return nil
case <-time.After(writeResponseTimeout):
return fmt.Errorf("timed out waiting for standby response after %s", writeResponseTimeout)
}
}
func (m *mk2Ser) pushInterfaceResponse(lineState byte) {
if m.ifaceAck == nil {
return
}
select {
case m.ifaceAck <- lineState:
default:
}
}
func (m *mk2Ser) readByte() byte {
buffer := make([]byte, 1)
_, err := io.ReadFull(m.p, buffer)
@@ -192,6 +448,12 @@ func (m *mk2Ser) handleFrame(l byte, frame []byte) {
m.setTarget()
case frameHeader:
switch frame[1] {
case interfaceFrame:
if len(frame) > 2 {
m.pushInterfaceResponse(frame[2])
}
case stateFrame:
m.pushStateResponse()
case vFrame:
m.versionDecode(frame[2:])
case winmonFrame:
@@ -200,6 +462,8 @@ func (m *mk2Ser) handleFrame(l byte, frame []byte) {
m.scaleDecode(frame[2:])
case commandReadRAMResponse:
m.stateDecode(frame[2:])
case commandWriteRAMResponse, commandWriteSettingResponse, commandUnsupportedResponse, commandWriteNotAllowedResponse:
m.pushWriteResponse(frame[2])
default:
logrus.Warnf("[handleFrame] invalid winmonFrame %v", frame[2:])
}
@@ -430,6 +694,12 @@ func getLEDs(ledsOn, ledsBlink byte) map[Led]LEDstate {
// Adds header and trailing crc for frame to send.
func (m *mk2Ser) sendCommand(data []byte) {
m.commandMu.Lock()
defer m.commandMu.Unlock()
m.sendCommandLocked(data)
}
func (m *mk2Ser) sendCommandLocked(data []byte) {
l := len(data)
dataOut := make([]byte, l+3)
dataOut[0] = byte(l + 1)

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"io"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
@@ -42,6 +43,7 @@ type testIo struct {
}
func NewIOStub(readBuffer []byte) io.ReadWriter {
writeBuffer = bytes.NewBuffer(nil)
return &testIo{
Reader: bytes.NewBuffer(readBuffer),
Writer: writeBuffer,
@@ -309,3 +311,127 @@ func Test_mk2Ser_calcFreq(t *testing.T) {
})
}
}
func Test_mk2Ser_WriteSetting(t *testing.T) {
testIO := NewIOStub(nil)
m := &mk2Ser{
p: testIO,
writeAck: make(chan byte, 1),
}
go func() {
time.Sleep(10 * time.Millisecond)
m.pushWriteResponse(commandWriteSettingResponse)
}()
err := m.WriteSetting(0x1234, 1234)
assert.NoError(t, err)
expected := []byte{
0x05, 0xff, 0x57, 0x33, 0x34, 0x12, 0x2c,
0x05, 0xff, 0x57, 0x34, 0xd2, 0x04, 0x9b,
}
assert.Equal(t, expected, writeBuffer.Bytes())
}
func Test_mk2Ser_WriteRAMVarRejected(t *testing.T) {
testIO := NewIOStub(nil)
m := &mk2Ser{
p: testIO,
writeAck: make(chan byte, 1),
}
go func() {
time.Sleep(10 * time.Millisecond)
m.pushWriteResponse(commandWriteNotAllowedResponse)
}()
err := m.WriteRAMVar(0x000d, 1)
assert.Error(t, err)
assert.ErrorContains(t, err, "rejected")
}
func Test_mk2Ser_SetPanelState(t *testing.T) {
testIO := NewIOStub(nil)
m := &mk2Ser{
p: testIO,
stateAck: make(chan struct{}, 1),
}
go func() {
time.Sleep(10 * time.Millisecond)
m.pushStateResponse()
}()
currentLimit := 16.5
err := m.SetPanelState(PanelSwitchOn, &currentLimit)
assert.NoError(t, err)
expected := []byte{
0x07, 0xff, 0x53, 0x03, 0xa5, 0x00, 0x01, 0x80, 0x7e,
}
assert.Equal(t, expected, writeBuffer.Bytes())
}
func Test_mk2Ser_SetPanelState_SwitchOnly(t *testing.T) {
testIO := NewIOStub(nil)
m := &mk2Ser{
p: testIO,
stateAck: make(chan struct{}, 1),
}
go func() {
time.Sleep(10 * time.Millisecond)
m.pushStateResponse()
}()
err := m.SetPanelState(PanelSwitchOff, nil)
assert.NoError(t, err)
expected := []byte{
0x07, 0xff, 0x53, 0x04, 0x00, 0x80, 0x01, 0x80, 0xa2,
}
assert.Equal(t, expected, writeBuffer.Bytes())
}
func Test_mk2Ser_SetStandby(t *testing.T) {
testIO := NewIOStub(nil)
m := &mk2Ser{
p: testIO,
ifaceAck: make(chan byte, 1),
}
go func() {
time.Sleep(10 * time.Millisecond)
m.pushInterfaceResponse(interfacePanelDetectFlag | interfaceStandbyFlag)
}()
err := m.SetStandby(true)
assert.NoError(t, err)
expected := []byte{
0x03, 0xff, 0x48, 0x03, 0xb3,
}
assert.Equal(t, expected, writeBuffer.Bytes())
}
func Test_mk2Ser_SetStandby_Disabled(t *testing.T) {
testIO := NewIOStub(nil)
m := &mk2Ser{
p: testIO,
ifaceAck: make(chan byte, 1),
}
go func() {
time.Sleep(10 * time.Millisecond)
m.pushInterfaceResponse(interfacePanelDetectFlag)
}()
err := m.SetStandby(false)
assert.NoError(t, err)
expected := []byte{
0x03, 0xff, 0x48, 0x01, 0xb5,
}
assert.Equal(t, expected, writeBuffer.Bytes())
}

View File

@@ -76,3 +76,29 @@ type Mk2 interface {
C() chan *Mk2Info
Close()
}
type PanelSwitchState byte
const (
// PanelSwitchChargerOnly enables charging only.
PanelSwitchChargerOnly PanelSwitchState = 0x01
// PanelSwitchInverterOnly enables inverter output and disables charging.
PanelSwitchInverterOnly PanelSwitchState = 0x02
// PanelSwitchOn enables both inverter and charger.
PanelSwitchOn PanelSwitchState = 0x03
// PanelSwitchOff disables inverter and charger.
PanelSwitchOff PanelSwitchState = 0x04
)
type SettingsWriter interface {
// WriteRAMVar writes a signed 16-bit value to a RAM variable id.
WriteRAMVar(id uint16, value int16) error
// WriteSetting writes a signed 16-bit value to a setting id.
WriteSetting(id uint16, value int16) error
// SetPanelState sends the MK2 "S" command using a virtual panel switch state.
// If currentLimitA is nil, the command does not update the AC current limit.
SetPanelState(switchState PanelSwitchState, currentLimitA *float64) error
// SetStandby configures the remote panel standby line.
// When enabled, the inverter is prevented from sleeping while switched off.
SetStandby(enabled bool) error
}

View File

@@ -37,6 +37,22 @@ func (m *mock) Close() {
}
func (m *mock) WriteRAMVar(_ uint16, _ int16) error {
return nil
}
func (m *mock) WriteSetting(_ uint16, _ int16) error {
return nil
}
func (m *mock) SetPanelState(_ PanelSwitchState, _ *float64) error {
return nil
}
func (m *mock) SetStandby(_ bool) error {
return nil
}
func (m *mock) genMockValues() {
mult := 1.0
ledState := LedOff

View File

@@ -1,7 +1,7 @@
package cli
import (
"github.com/diebietse/invertergui/mk2driver"
"invertergui/mk2driver"
"github.com/sirupsen/logrus"
)

View File

@@ -2,9 +2,14 @@ package mqttclient
import (
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"sync"
"time"
"github.com/diebietse/invertergui/mk2driver"
"invertergui/mk2driver"
mqtt "github.com/eclipse/paho.mqtt.golang"
"github.com/sirupsen/logrus"
)
@@ -13,42 +18,758 @@ var log = logrus.WithField("ctx", "inverter-gui-mqtt")
const keepAlive = 5 * time.Second
const (
commandKindSetting = "setting"
commandKindRAMVar = "ram_var"
commandKindPanel = "panel_state"
commandKindStandby = "standby"
writeStatusOK = "ok"
writeStatusError = "error"
)
type HomeAssistantConfig struct {
Enabled bool
DiscoveryPrefix string
NodeID string
DeviceName string
}
// Config sets MQTT client configuration
type Config struct {
Broker string
ClientID string
Topic string
CommandTopic string
StatusTopic string
HomeAssistant HomeAssistantConfig
Username string
Password string
}
// New creates an MQTT client that starts publishing MK2 data as it is received.
func New(mk2 mk2driver.Mk2, config Config) error {
type writeCommand struct {
RequestID string
Kind string
ID uint16
Value int16
HasSwitch bool
SwitchState mk2driver.PanelSwitchState
SwitchName string
CurrentLimitA *float64
Standby *bool
}
type writeCommandPayload struct {
RequestID string `json:"request_id"`
Kind string `json:"kind"`
Type string `json:"type"`
ID *uint16 `json:"id"`
Value json.RawMessage `json:"value"`
Switch string `json:"switch"`
SwitchState string `json:"switch_state"`
CurrentLimitA *float64 `json:"current_limit"`
Standby *bool `json:"standby"`
}
type writeStatus struct {
RequestID string `json:"request_id,omitempty"`
Status string `json:"status"`
Kind string `json:"kind,omitempty"`
ID uint16 `json:"id"`
Value int16 `json:"value"`
Switch string `json:"switch,omitempty"`
CurrentLimitA *float64 `json:"current_limit,omitempty"`
Standby *bool `json:"standby,omitempty"`
Error string `json:"error,omitempty"`
Timestamp time.Time `json:"timestamp"`
}
type haDiscoveryDefinition struct {
Component string
ObjectID string
Config map[string]any
}
type panelStateCache struct {
mu sync.Mutex
hasSwitch bool
switchName string
switchState mk2driver.PanelSwitchState
}
// New creates an MQTT client that publishes MK2 updates and optionally handles setting write commands.
func New(mk2 mk2driver.Mk2, writer mk2driver.SettingsWriter, config Config) error {
c := mqtt.NewClient(getOpts(config))
if token := c.Connect(); token.Wait() && token.Error() != nil {
return token.Error()
}
cache := &panelStateCache{}
if config.HomeAssistant.Enabled {
if err := publishHAAvailability(c, config, "online"); err != nil {
return fmt.Errorf("could not publish Home Assistant availability: %w", err)
}
if err := publishHADiscovery(c, config); err != nil {
return fmt.Errorf("could not publish Home Assistant discovery payloads: %w", err)
}
if writer != nil {
if err := subscribeHAPanelModeState(c, config, cache); err != nil {
log.Warnf("Could not subscribe to Home Assistant panel mode state topic: %v", err)
}
}
}
if config.CommandTopic != "" {
if writer == nil {
log.Warnf("MQTT command topic %q configured, but no settings writer is available", config.CommandTopic)
} else {
t := c.Subscribe(config.CommandTopic, 1, commandHandler(c, writer, config, cache))
t.Wait()
if t.Error() != nil {
return fmt.Errorf("could not subscribe to MQTT command topic %q: %w", config.CommandTopic, t.Error())
}
log.Infof("Subscribed to MQTT command topic: %s", config.CommandTopic)
}
}
go func() {
for e := range mk2.C() {
if e.Valid {
data, err := json.Marshal(e)
if err != nil {
log.Errorf("Could not parse data source: %v", err)
if e == nil || !e.Valid {
continue
}
t := c.Publish(config.Topic, 0, false, data)
t.Wait()
if t.Error() != nil {
log.Errorf("Could not publish data: %v", t.Error())
}
if err := publishJSON(c, config.Topic, e, 0, false); err != nil {
log.Errorf("Could not publish update to MQTT topic %q: %v", config.Topic, err)
}
}
}()
return nil
}
func subscribeHAPanelModeState(client mqtt.Client, config Config, cache *panelStateCache) error {
if cache == nil {
return nil
}
stateTopic := haPanelSwitchStateTopic(config)
t := client.Subscribe(stateTopic, 1, func(_ mqtt.Client, msg mqtt.Message) {
switchState, switchName, err := normalizePanelSwitch(string(msg.Payload()))
if err != nil {
return
}
cache.remember(writeCommand{
Kind: commandKindPanel,
HasSwitch: true,
SwitchState: switchState,
SwitchName: switchName,
})
})
t.Wait()
return t.Error()
}
func commandHandler(client mqtt.Client, writer mk2driver.SettingsWriter, config Config, cache *panelStateCache) mqtt.MessageHandler {
if cache == nil {
cache = &panelStateCache{}
}
return func(_ mqtt.Client, msg mqtt.Message) {
cmd, err := decodeWriteCommand(msg.Payload())
if err != nil {
log.Errorf("Invalid MQTT write command payload from topic %q: %v", msg.Topic(), err)
publishWriteStatus(client, config.StatusTopic, writeStatus{
Status: writeStatusError,
Error: err.Error(),
Timestamp: time.Now().UTC(),
})
return
}
execCmd := cmd
status := writeStatus{
RequestID: cmd.RequestID,
Status: writeStatusOK,
Kind: cmd.Kind,
Timestamp: time.Now().UTC(),
}
switch cmd.Kind {
case commandKindPanel:
execCmd, err = cache.resolvePanelCommand(cmd)
if err != nil {
status.Status = writeStatusError
status.Error = err.Error()
log.Errorf("Invalid MQTT write command %s: %v", formatWriteCommandLog(cmd), err)
publishWriteStatus(client, config.StatusTopic, status)
return
}
status.Switch = execCmd.SwitchName
status.CurrentLimitA = execCmd.CurrentLimitA
case commandKindStandby:
status.Standby = copyBoolPtr(execCmd.Standby)
default:
status.ID = cmd.ID
status.Value = cmd.Value
}
if err := executeWriteCommand(writer, execCmd); err != nil {
status.Status = writeStatusError
status.Error = err.Error()
log.Errorf("Failed MQTT write command %s: %v", formatWriteCommandLog(execCmd), err)
} else {
log.Infof("Applied MQTT write command %s", formatWriteCommandLog(execCmd))
cache.remember(execCmd)
if config.HomeAssistant.Enabled {
if err := publishHAControlState(client, config, execCmd); err != nil {
log.Errorf("Could not publish Home Assistant control state update: %v", err)
}
}
}
publishWriteStatus(client, config.StatusTopic, status)
}
}
func (c *panelStateCache) resolvePanelCommand(cmd writeCommand) (writeCommand, error) {
if cmd.Kind != commandKindPanel {
return cmd, nil
}
if cmd.HasSwitch {
return cmd, nil
}
c.mu.Lock()
defer c.mu.Unlock()
if !c.hasSwitch {
return writeCommand{}, errors.New("panel_state command missing switch and no prior mode is known; set mode once first")
}
cmd.HasSwitch = true
cmd.SwitchName = c.switchName
cmd.SwitchState = c.switchState
return cmd, nil
}
func (c *panelStateCache) remember(cmd writeCommand) {
if cmd.Kind != commandKindPanel || !cmd.HasSwitch {
return
}
c.mu.Lock()
c.hasSwitch = true
c.switchName = cmd.SwitchName
c.switchState = cmd.SwitchState
c.mu.Unlock()
}
func decodeWriteCommand(payload []byte) (writeCommand, error) {
msg := writeCommandPayload{}
if err := json.Unmarshal(payload, &msg); err != nil {
return writeCommand{}, fmt.Errorf("invalid JSON payload: %w", err)
}
kind := msg.Kind
if kind == "" {
kind = msg.Type
}
normalizedKind, err := normalizeWriteKind(kind)
if err != nil {
return writeCommand{}, err
}
if normalizedKind == commandKindPanel {
switchName := msg.Switch
if switchName == "" {
switchName = msg.SwitchState
}
hasSwitch := false
switchState := mk2driver.PanelSwitchState(0)
normalizedSwitchName := ""
if switchName != "" {
var err error
switchState, normalizedSwitchName, err = normalizePanelSwitch(switchName)
if err != nil {
return writeCommand{}, err
}
hasSwitch = true
}
if msg.CurrentLimitA != nil && *msg.CurrentLimitA < 0 {
return writeCommand{}, fmt.Errorf("current_limit must be >= 0, got %.3f", *msg.CurrentLimitA)
}
if !hasSwitch && msg.CurrentLimitA == nil {
return writeCommand{}, errors.New(`missing required field "switch" (or "switch_state"), or "current_limit"`)
}
return writeCommand{
RequestID: msg.RequestID,
Kind: normalizedKind,
HasSwitch: hasSwitch,
SwitchState: switchState,
SwitchName: normalizedSwitchName,
CurrentLimitA: msg.CurrentLimitA,
}, nil
}
if normalizedKind == commandKindStandby {
standby, err := decodeStandbyValue(msg)
if err != nil {
return writeCommand{}, err
}
return writeCommand{
RequestID: msg.RequestID,
Kind: normalizedKind,
Standby: &standby,
}, nil
}
if msg.ID == nil {
return writeCommand{}, errors.New(`missing required field "id"`)
}
value, err := decodeInt16Value(msg.Value)
if err != nil {
return writeCommand{}, err
}
return writeCommand{
RequestID: msg.RequestID,
Kind: normalizedKind,
ID: *msg.ID,
Value: value,
}, nil
}
func normalizeWriteKind(raw string) (string, error) {
switch strings.ToLower(strings.TrimSpace(raw)) {
case "setting", "settings":
return commandKindSetting, nil
case "ram", "ramvar", "ram_var", "ram-variable", "ramvariable":
return commandKindRAMVar, nil
case "panel", "panel_state", "switch", "remote_panel":
return commandKindPanel, nil
case "standby", "panel_standby", "remote_panel_standby":
return commandKindStandby, nil
default:
return "", fmt.Errorf("unsupported write command kind %q", raw)
}
}
func normalizePanelSwitch(raw string) (mk2driver.PanelSwitchState, string, error) {
switch strings.ToLower(strings.TrimSpace(raw)) {
case "1", "charger_only", "charger-only", "charger":
return mk2driver.PanelSwitchChargerOnly, "charger_only", nil
case "2", "inverter_only", "inverter-only", "inverter":
return mk2driver.PanelSwitchInverterOnly, "inverter_only", nil
case "3", "on":
return mk2driver.PanelSwitchOn, "on", nil
case "4", "off":
return mk2driver.PanelSwitchOff, "off", nil
default:
return 0, "", fmt.Errorf("unsupported panel switch %q", raw)
}
}
func executeWriteCommand(writer mk2driver.SettingsWriter, cmd writeCommand) error {
if writer == nil {
return errors.New("settings writer is not available")
}
switch cmd.Kind {
case commandKindPanel:
if !cmd.HasSwitch {
return errors.New("panel_state command requires a switch state")
}
return writer.SetPanelState(cmd.SwitchState, cmd.CurrentLimitA)
case commandKindStandby:
if cmd.Standby == nil {
return errors.New("standby command missing standby value")
}
return writer.SetStandby(*cmd.Standby)
case commandKindRAMVar:
return writer.WriteRAMVar(cmd.ID, cmd.Value)
case commandKindSetting:
return writer.WriteSetting(cmd.ID, cmd.Value)
default:
return fmt.Errorf("unsupported write command kind %q", cmd.Kind)
}
}
func formatWriteCommandLog(cmd writeCommand) string {
switch cmd.Kind {
case commandKindPanel:
switchName := cmd.SwitchName
if switchName == "" {
switchName = "<cached>"
}
if cmd.CurrentLimitA == nil {
return fmt.Sprintf("kind=%s switch=%s", cmd.Kind, switchName)
}
return fmt.Sprintf("kind=%s switch=%s current_limit=%.3f", cmd.Kind, switchName, *cmd.CurrentLimitA)
case commandKindStandby:
if cmd.Standby == nil {
return fmt.Sprintf("kind=%s standby=<unset>", cmd.Kind)
}
return fmt.Sprintf("kind=%s standby=%t", cmd.Kind, *cmd.Standby)
default:
return fmt.Sprintf("kind=%s id=%d value=%d", cmd.Kind, cmd.ID, cmd.Value)
}
}
func decodeInt16Value(raw json.RawMessage) (int16, error) {
if len(raw) == 0 {
return 0, errors.New(`missing required field "value"`)
}
var value int16
if err := json.Unmarshal(raw, &value); err != nil {
return 0, fmt.Errorf(`field "value" must be a signed integer: %w`, err)
}
return value, nil
}
func decodeStandbyValue(msg writeCommandPayload) (bool, error) {
if msg.Standby != nil {
return *msg.Standby, nil
}
if len(msg.Value) == 0 {
return false, errors.New(`missing required field "standby" (or boolean "value")`)
}
var boolValue bool
if err := json.Unmarshal(msg.Value, &boolValue); err == nil {
return boolValue, nil
}
var stringValue string
if err := json.Unmarshal(msg.Value, &stringValue); err == nil {
switch strings.ToLower(strings.TrimSpace(stringValue)) {
case "1", "true", "on", "enable", "enabled":
return true, nil
case "0", "false", "off", "disable", "disabled":
return false, nil
}
}
var intValue int
if err := json.Unmarshal(msg.Value, &intValue); err == nil {
switch intValue {
case 1:
return true, nil
case 0:
return false, nil
}
}
return false, errors.New(`field "standby" must be true/false`)
}
func publishWriteStatus(client mqtt.Client, topic string, status writeStatus) {
if topic == "" {
return
}
if err := publishJSON(client, topic, status, 1, false); err != nil {
log.Errorf("Could not publish command status to MQTT topic %q: %v", topic, err)
}
}
func publishHADiscovery(client mqtt.Client, config Config) error {
definitions := buildHADiscoveryDefinitions(config)
prefix := haDiscoveryPrefix(config)
nodeID := haNodeID(config)
for _, def := range definitions {
topic := fmt.Sprintf("%s/%s/%s/%s/config", prefix, def.Component, nodeID, def.ObjectID)
if err := publishJSON(client, topic, def.Config, 1, true); err != nil {
return fmt.Errorf("could not publish discovery for %s/%s: %w", def.Component, def.ObjectID, err)
}
}
return nil
}
func buildHADiscoveryDefinitions(config Config) []haDiscoveryDefinition {
if !config.HomeAssistant.Enabled {
return nil
}
nodeID := haNodeID(config)
device := map[string]any{
"identifiers": []string{fmt.Sprintf("invertergui_%s", nodeID)},
"name": haDeviceName(config),
"manufacturer": "Victron Energy",
"model": "MultiPlus",
"sw_version": "invertergui",
}
availabilityTopic := haAvailabilityTopic(config)
stateTopic := config.Topic
sensors := []haDiscoveryDefinition{
buildHASensor(device, availabilityTopic, stateTopic, nodeID, "battery_voltage", "Battery Voltage", "{{ value_json.BatVoltage }}", "V", "voltage", "measurement"),
buildHASensor(device, availabilityTopic, stateTopic, nodeID, "battery_current", "Battery Current", "{{ value_json.BatCurrent }}", "A", "current", "measurement"),
buildHASensor(device, availabilityTopic, stateTopic, nodeID, "battery_charge", "Battery Charge", "{{ ((value_json.ChargeState | float(0)) * 100) | round(1) }}", "%", "battery", "measurement"),
buildHASensor(device, availabilityTopic, stateTopic, nodeID, "input_voltage", "Input Voltage", "{{ value_json.InVoltage }}", "V", "voltage", "measurement"),
buildHASensor(device, availabilityTopic, stateTopic, nodeID, "input_current", "Input Current", "{{ value_json.InCurrent }}", "A", "current", "measurement"),
buildHASensor(device, availabilityTopic, stateTopic, nodeID, "input_frequency", "Input Frequency", "{{ value_json.InFrequency }}", "Hz", "frequency", "measurement"),
buildHASensor(device, availabilityTopic, stateTopic, nodeID, "output_voltage", "Output Voltage", "{{ value_json.OutVoltage }}", "V", "voltage", "measurement"),
buildHASensor(device, availabilityTopic, stateTopic, nodeID, "output_current", "Output Current", "{{ value_json.OutCurrent }}", "A", "current", "measurement"),
buildHASensor(device, availabilityTopic, stateTopic, nodeID, "output_frequency", "Output Frequency", "{{ value_json.OutFrequency }}", "Hz", "frequency", "measurement"),
buildHASensor(device, availabilityTopic, stateTopic, nodeID, "input_power", "Input Power", "{{ ((value_json.InVoltage | float(0)) * (value_json.InCurrent | float(0))) | round(1) }}", "VA", "", "measurement"),
buildHASensor(device, availabilityTopic, stateTopic, nodeID, "output_power", "Output Power", "{{ ((value_json.OutVoltage | float(0)) * (value_json.OutCurrent | float(0))) | round(1) }}", "VA", "", "measurement"),
{
Component: "binary_sensor",
ObjectID: "data_valid",
Config: map[string]any{
"name": "Data Valid",
"unique_id": fmt.Sprintf("%s_data_valid", nodeID),
"state_topic": stateTopic,
"value_template": "{{ value_json.Valid }}",
"payload_on": "true",
"payload_off": "false",
"availability_topic": availabilityTopic,
"device": device,
"entity_category": "diagnostic",
},
},
}
if config.CommandTopic != "" {
sensors = append(sensors,
haDiscoveryDefinition{
Component: "select",
ObjectID: "remote_panel_mode",
Config: map[string]any{
"name": "Remote Panel Mode",
"unique_id": fmt.Sprintf("%s_remote_panel_mode", nodeID),
"state_topic": haPanelSwitchStateTopic(config),
"command_topic": config.CommandTopic,
"command_template": "{\"kind\":\"panel_state\",\"switch\":\"{{ value }}\"}",
"options": []string{"charger_only", "inverter_only", "on", "off"},
"availability_topic": availabilityTopic,
"device": device,
"icon": "mdi:transmission-tower-export",
},
},
haDiscoveryDefinition{
Component: "number",
ObjectID: "remote_panel_current_limit",
Config: map[string]any{
"name": "Remote Panel Current Limit",
"unique_id": fmt.Sprintf("%s_remote_panel_current_limit", nodeID),
"state_topic": haCurrentLimitStateTopic(config),
"command_topic": config.CommandTopic,
"command_template": "{\"kind\":\"panel_state\",\"current_limit\":{{ value | float(0) }}}",
"unit_of_measurement": "A",
"device_class": "current",
"mode": "box",
"min": 0,
"max": 100,
"step": 0.1,
"availability_topic": availabilityTopic,
"device": device,
"icon": "mdi:current-ac",
},
},
haDiscoveryDefinition{
Component: "switch",
ObjectID: "remote_panel_standby",
Config: map[string]any{
"name": "Remote Panel Standby",
"unique_id": fmt.Sprintf("%s_remote_panel_standby", nodeID),
"state_topic": haStandbyStateTopic(config),
"command_topic": config.CommandTopic,
"payload_on": "{\"kind\":\"standby\",\"standby\":true}",
"payload_off": "{\"kind\":\"standby\",\"standby\":false}",
"state_on": "ON",
"state_off": "OFF",
"availability_topic": availabilityTopic,
"device": device,
"icon": "mdi:power-sleep",
},
},
)
}
return sensors
}
func buildHASensor(device map[string]any, availabilityTopic, stateTopic, nodeID, objectID, name, valueTemplate, unit, deviceClass, stateClass string) haDiscoveryDefinition {
config := map[string]any{
"name": name,
"unique_id": fmt.Sprintf("%s_%s", nodeID, objectID),
"state_topic": stateTopic,
"value_template": valueTemplate,
"availability_topic": availabilityTopic,
"device": device,
}
if unit != "" {
config["unit_of_measurement"] = unit
}
if deviceClass != "" {
config["device_class"] = deviceClass
}
if stateClass != "" {
config["state_class"] = stateClass
}
return haDiscoveryDefinition{
Component: "sensor",
ObjectID: objectID,
Config: config,
}
}
func publishHAAvailability(client mqtt.Client, config Config, status string) error {
return publishString(client, haAvailabilityTopic(config), status, 1, true)
}
func publishHAControlState(client mqtt.Client, config Config, cmd writeCommand) error {
switch cmd.Kind {
case commandKindPanel:
if err := publishString(client, haPanelSwitchStateTopic(config), cmd.SwitchName, 1, true); err != nil {
return err
}
if cmd.CurrentLimitA != nil {
limit := strconv.FormatFloat(*cmd.CurrentLimitA, 'f', 1, 64)
if err := publishString(client, haCurrentLimitStateTopic(config), limit, 1, true); err != nil {
return err
}
}
case commandKindStandby:
if cmd.Standby == nil {
return nil
}
state := "OFF"
if *cmd.Standby {
state = "ON"
}
if err := publishString(client, haStandbyStateTopic(config), state, 1, true); err != nil {
return err
}
}
return nil
}
func publishJSON(client mqtt.Client, topic string, payload any, qos byte, retained bool) error {
if topic == "" {
return errors.New("topic is empty")
}
data, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("could not marshal payload: %w", err)
}
t := client.Publish(topic, qos, retained, data)
t.Wait()
if t.Error() != nil {
return t.Error()
}
return nil
}
func publishString(client mqtt.Client, topic, payload string, qos byte, retained bool) error {
if topic == "" {
return errors.New("topic is empty")
}
t := client.Publish(topic, qos, retained, payload)
t.Wait()
if t.Error() != nil {
return t.Error()
}
return nil
}
func mqttTopicRoot(topic string) string {
t := strings.Trim(strings.TrimSpace(topic), "/")
if t == "" {
return "invertergui"
}
if strings.HasSuffix(t, "/updates") {
root := strings.TrimSuffix(t, "/updates")
if root != "" {
return root
}
}
return t
}
func haAvailabilityTopic(config Config) string {
return fmt.Sprintf("%s/homeassistant/availability", mqttTopicRoot(config.Topic))
}
func haPanelSwitchStateTopic(config Config) string {
return fmt.Sprintf("%s/homeassistant/remote_panel_mode/state", mqttTopicRoot(config.Topic))
}
func haCurrentLimitStateTopic(config Config) string {
return fmt.Sprintf("%s/homeassistant/remote_panel_current_limit/state", mqttTopicRoot(config.Topic))
}
func haStandbyStateTopic(config Config) string {
return fmt.Sprintf("%s/homeassistant/remote_panel_standby/state", mqttTopicRoot(config.Topic))
}
func haDiscoveryPrefix(config Config) string {
prefix := strings.Trim(strings.TrimSpace(config.HomeAssistant.DiscoveryPrefix), "/")
if prefix == "" {
return "homeassistant"
}
return prefix
}
func haNodeID(config Config) string {
nodeID := normalizeID(config.HomeAssistant.NodeID)
if nodeID == "" {
nodeID = normalizeID(config.ClientID)
}
if nodeID == "" {
return "invertergui"
}
return nodeID
}
func haDeviceName(config Config) string {
name := strings.TrimSpace(config.HomeAssistant.DeviceName)
if name == "" {
return "Victron Inverter"
}
return name
}
func normalizeID(in string) string {
trimmed := strings.TrimSpace(strings.ToLower(in))
if trimmed == "" {
return ""
}
var b strings.Builder
lastUnderscore := false
for _, r := range trimmed {
valid := (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' || r == '_'
if valid {
b.WriteRune(r)
lastUnderscore = false
continue
}
if !lastUnderscore {
b.WriteRune('_')
lastUnderscore = true
}
}
return strings.Trim(b.String(), "_")
}
func copyBoolPtr(in *bool) *bool {
if in == nil {
return nil
}
value := *in
return &value
}
func getOpts(config Config) *mqtt.ClientOptions {
opts := mqtt.NewClientOptions()
opts.AddBroker(config.Broker)
@@ -60,6 +781,9 @@ func getOpts(config Config) *mqtt.ClientOptions {
if config.Password != "" {
opts.SetPassword(config.Password)
}
if config.HomeAssistant.Enabled {
opts.SetWill(haAvailabilityTopic(config), "offline", 1, true)
}
opts.SetKeepAlive(keepAlive)
opts.SetOnConnectHandler(func(mqtt.Client) {
@@ -67,7 +791,6 @@ func getOpts(config Config) *mqtt.ClientOptions {
})
opts.SetConnectionLostHandler(func(_ mqtt.Client, err error) {
log.Errorf("Client connection to broker lost: %v", err)
})
return opts
}

View File

@@ -0,0 +1,316 @@
package mqttclient
import (
"testing"
"invertergui/mk2driver"
"github.com/stretchr/testify/assert"
)
type fakeWriter struct {
lastKind string
lastID uint16
lastValue int16
lastSwitchState mk2driver.PanelSwitchState
lastCurrentLimit *float64
lastStandby *bool
err error
}
func (f *fakeWriter) WriteRAMVar(id uint16, value int16) error {
f.lastKind = commandKindRAMVar
f.lastID = id
f.lastValue = value
return f.err
}
func (f *fakeWriter) WriteSetting(id uint16, value int16) error {
f.lastKind = commandKindSetting
f.lastID = id
f.lastValue = value
return f.err
}
func (f *fakeWriter) SetPanelState(switchState mk2driver.PanelSwitchState, currentLimitA *float64) error {
f.lastKind = commandKindPanel
f.lastSwitchState = switchState
f.lastCurrentLimit = currentLimitA
return f.err
}
func (f *fakeWriter) SetStandby(standby bool) error {
f.lastKind = commandKindStandby
f.lastStandby = &standby
return f.err
}
func Test_decodeWriteCommand(t *testing.T) {
currentLimit := 16.5
tests := []struct {
name string
payload string
check func(*testing.T, writeCommand)
wantErr string
}{
{
name: "setting",
payload: `{"request_id":"abc","kind":"setting","id":15,"value":-5}`,
check: func(t *testing.T, got writeCommand) {
assert.Equal(t, writeCommand{
RequestID: "abc",
Kind: commandKindSetting,
ID: 15,
Value: -5,
}, got)
},
},
{
name: "ram_var alias from type",
payload: `{"type":"ramvar","id":2,"value":7}`,
check: func(t *testing.T, got writeCommand) {
assert.Equal(t, writeCommand{
Kind: commandKindRAMVar,
ID: 2,
Value: 7,
}, got)
},
},
{
name: "panel state",
payload: `{"kind":"panel_state","switch":"on","current_limit":16.5}`,
check: func(t *testing.T, got writeCommand) {
assert.Equal(t, commandKindPanel, got.Kind)
assert.True(t, got.HasSwitch)
assert.Equal(t, mk2driver.PanelSwitchOn, got.SwitchState)
assert.Equal(t, "on", got.SwitchName)
if assert.NotNil(t, got.CurrentLimitA) {
assert.Equal(t, currentLimit, *got.CurrentLimitA)
}
},
},
{
name: "panel current limit only",
payload: `{"kind":"panel_state","current_limit":12.0}`,
check: func(t *testing.T, got writeCommand) {
assert.Equal(t, commandKindPanel, got.Kind)
assert.False(t, got.HasSwitch)
assert.Nil(t, got.Standby)
if assert.NotNil(t, got.CurrentLimitA) {
assert.Equal(t, 12.0, *got.CurrentLimitA)
}
},
},
{
name: "standby bool",
payload: `{"kind":"standby","standby":true}`,
check: func(t *testing.T, got writeCommand) {
assert.Equal(t, commandKindStandby, got.Kind)
if assert.NotNil(t, got.Standby) {
assert.True(t, *got.Standby)
}
},
},
{
name: "standby using value string",
payload: `{"kind":"standby","value":"OFF"}`,
check: func(t *testing.T, got writeCommand) {
assert.Equal(t, commandKindStandby, got.Kind)
if assert.NotNil(t, got.Standby) {
assert.False(t, *got.Standby)
}
},
},
{
name: "missing id",
payload: `{"kind":"setting","value":1}`,
wantErr: `missing required field "id"`,
},
{
name: "missing panel switch and current limit",
payload: `{"kind":"panel_state"}`,
wantErr: `missing required field "switch"`,
},
{
name: "invalid standby",
payload: `{"kind":"standby","value":"banana"}`,
wantErr: `field "standby" must be true/false`,
},
{
name: "invalid kind",
payload: `{"kind":"unknown","id":1,"value":1}`,
wantErr: `unsupported write command kind "unknown"`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := decodeWriteCommand([]byte(tt.payload))
if tt.wantErr != "" {
assert.Error(t, err)
assert.ErrorContains(t, err, tt.wantErr)
return
}
assert.NoError(t, err)
tt.check(t, got)
})
}
}
func Test_executeWriteCommand(t *testing.T) {
limit := 8.0
standby := true
tests := []struct {
name string
cmd writeCommand
want string
}{
{
name: "setting",
cmd: writeCommand{
Kind: commandKindSetting,
ID: 9,
Value: 2,
},
want: commandKindSetting,
},
{
name: "ram_var",
cmd: writeCommand{
Kind: commandKindRAMVar,
ID: 3,
Value: -1,
},
want: commandKindRAMVar,
},
{
name: "panel_state",
cmd: writeCommand{
Kind: commandKindPanel,
HasSwitch: true,
SwitchState: mk2driver.PanelSwitchInverterOnly,
CurrentLimitA: &limit,
},
want: commandKindPanel,
},
{
name: "standby",
cmd: writeCommand{
Kind: commandKindStandby,
Standby: &standby,
},
want: commandKindStandby,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
writer := &fakeWriter{}
err := executeWriteCommand(writer, tt.cmd)
assert.NoError(t, err)
assert.Equal(t, tt.want, writer.lastKind)
switch tt.want {
case commandKindPanel:
assert.Equal(t, tt.cmd.SwitchState, writer.lastSwitchState)
if assert.NotNil(t, writer.lastCurrentLimit) {
assert.Equal(t, *tt.cmd.CurrentLimitA, *writer.lastCurrentLimit)
}
case commandKindStandby:
if assert.NotNil(t, writer.lastStandby) {
assert.Equal(t, *tt.cmd.Standby, *writer.lastStandby)
}
default:
assert.Equal(t, tt.cmd.ID, writer.lastID)
assert.Equal(t, tt.cmd.Value, writer.lastValue)
}
})
}
}
func Test_buildHADiscoveryDefinitions(t *testing.T) {
cfg := Config{
Topic: "invertergui/updates",
CommandTopic: "invertergui/settings/set",
HomeAssistant: HomeAssistantConfig{
Enabled: true,
DiscoveryPrefix: "homeassistant",
NodeID: "victron_main",
DeviceName: "Shed Victron",
},
}
definitions := buildHADiscoveryDefinitions(cfg)
assert.NotEmpty(t, definitions)
var panelMode *haDiscoveryDefinition
var panelCurrentLimit *haDiscoveryDefinition
var panelStandby *haDiscoveryDefinition
var batteryVoltage *haDiscoveryDefinition
for i := range definitions {
def := &definitions[i]
if def.Component == "select" && def.ObjectID == "remote_panel_mode" {
panelMode = def
}
if def.Component == "number" && def.ObjectID == "remote_panel_current_limit" {
panelCurrentLimit = def
}
if def.Component == "switch" && def.ObjectID == "remote_panel_standby" {
panelStandby = def
}
if def.Component == "sensor" && def.ObjectID == "battery_voltage" {
batteryVoltage = def
}
}
if assert.NotNil(t, panelMode) {
assert.Equal(t, cfg.CommandTopic, panelMode.Config["command_topic"])
assert.Equal(t, haPanelSwitchStateTopic(cfg), panelMode.Config["state_topic"])
}
if assert.NotNil(t, panelCurrentLimit) {
assert.Equal(t, cfg.CommandTopic, panelCurrentLimit.Config["command_topic"])
assert.Equal(t, haCurrentLimitStateTopic(cfg), panelCurrentLimit.Config["state_topic"])
}
if assert.NotNil(t, panelStandby) {
assert.Equal(t, cfg.CommandTopic, panelStandby.Config["command_topic"])
assert.Equal(t, haStandbyStateTopic(cfg), panelStandby.Config["state_topic"])
}
if assert.NotNil(t, batteryVoltage) {
assert.Equal(t, cfg.Topic, batteryVoltage.Config["state_topic"])
}
}
func Test_panelStateCacheResolvePanelCommand(t *testing.T) {
cache := &panelStateCache{}
_, err := cache.resolvePanelCommand(writeCommand{
Kind: commandKindPanel,
CurrentLimitA: float64Ptr(12),
})
assert.Error(t, err)
cache.remember(writeCommand{
Kind: commandKindPanel,
HasSwitch: true,
SwitchState: mk2driver.PanelSwitchOn,
SwitchName: "on",
})
resolved, err := cache.resolvePanelCommand(writeCommand{
Kind: commandKindPanel,
CurrentLimitA: float64Ptr(10),
})
assert.NoError(t, err)
assert.True(t, resolved.HasSwitch)
assert.Equal(t, mk2driver.PanelSwitchOn, resolved.SwitchState)
assert.Equal(t, "on", resolved.SwitchName)
}
func float64Ptr(in float64) *float64 {
return &in
}
func Test_normalizeID(t *testing.T) {
assert.Equal(t, "victron_main_01", normalizeID("Victron Main #01"))
assert.Equal(t, "inverter-gui", normalizeID(" inverter-gui "))
assert.Equal(t, "", normalizeID(" "))
}

View File

@@ -36,7 +36,7 @@ import (
"net/http"
"time"
"github.com/diebietse/invertergui/mk2driver"
"invertergui/mk2driver"
"github.com/sirupsen/logrus"
)

View File

@@ -6,7 +6,7 @@ import (
"net/http/httptest"
"testing"
"github.com/diebietse/invertergui/mk2driver"
"invertergui/mk2driver"
)
func TestServer(_ *testing.T) {

View File

@@ -31,7 +31,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package prometheus
import (
"github.com/diebietse/invertergui/mk2driver"
"invertergui/mk2driver"
"github.com/prometheus/client_golang/prometheus"
)

View File

@@ -98,6 +98,117 @@
<div class="alert alert-danger" role="alert" v-if="error.has_error">
{{ error.error_message }}
</div>
<div
class="alert"
v-if="control.message !== ''"
v-bind:class="[control.has_error ? 'alert-danger' : 'alert-success']"
>
{{ control.message }}
</div>
<div class="row">
<div class="col">
<div class="card mb-3">
<div class="card-body">
<h4 class="card-title">Remote Panel Control</h4>
<p class="text-muted mb-2">
Mode and current limit are applied together, equivalent to
<code>set_remote_panel_state</code>.
</p>
<p class="mb-1">
<strong>Current Mode:</strong>
{{ remoteModeLabel(state.remote_panel) }}
</p>
<p class="mb-1">
<strong>Current Limit:</strong>
{{ state.remote_panel.current_limit === null || state.remote_panel.current_limit === undefined ? 'Unknown' : state.remote_panel.current_limit + ' A' }}
</p>
<p class="mb-3">
<strong>Standby:</strong>
{{ remoteStandbyLabel(state.remote_panel) }}
</p>
<div class="row">
<div class="col-md-8">
<form v-on:submit.prevent="applyRemotePanelState">
<div class="form-row">
<div class="form-group col-md-6">
<label for="modeSelect">Remote Panel Mode</label>
<select
class="form-control"
id="modeSelect"
v-model="remote_form.mode"
v-bind:disabled="!state.remote_panel.writable || control.busy"
>
<option value="on">on</option>
<option value="off">off</option>
<option value="charger_only">charger_only</option>
<option value="inverter_only">inverter_only</option>
</select>
</div>
<div class="form-group col-md-6">
<label for="currentLimitInput">AC Input Current Limit (A)</label>
<input
id="currentLimitInput"
type="number"
min="0"
step="0.1"
class="form-control"
v-model="remote_form.current_limit"
placeholder="leave blank to keep current"
v-bind:disabled="!state.remote_panel.writable || control.busy"
/>
</div>
</div>
<button
type="submit"
class="btn btn-primary"
v-bind:disabled="!state.remote_panel.writable || control.busy"
>
Apply Mode + Current Limit
</button>
</form>
</div>
<div class="col-md-4">
<form v-on:submit.prevent="applyStandby">
<div class="form-group">
<div class="form-check mt-4">
<input
id="standbySwitch"
type="checkbox"
class="form-check-input"
v-model="remote_form.standby"
v-bind:disabled="!state.remote_panel.writable || control.busy"
/>
<label class="form-check-label" for="standbySwitch">
Prevent sleep while off
</label>
</div>
</div>
<button
type="submit"
class="btn btn-secondary"
v-bind:disabled="!state.remote_panel.writable || control.busy"
>
Apply Standby
</button>
</form>
</div>
</div>
<div class="mt-3 text-muted" v-if="state.remote_panel.last_updated">
Last update {{ state.remote_panel.last_updated }}
<span v-if="state.remote_panel.last_command">
({{ state.remote_panel.last_command }})
</span>
</div>
<div class="mt-2 text-danger" v-if="state.remote_panel.last_error">
{{ state.remote_panel.last_error }}
</div>
<div class="mt-2 text-warning" v-if="!state.remote_panel.writable">
Remote control is unavailable for this data source.
</div>
</div>
</div>
</div>
</div>
<div class="row">
<div class="col">
<hr />

View File

@@ -3,15 +3,20 @@ const timeoutMax = 30000;
const timeoutMin = 1000;
var timeout = timeoutMin;
function loadContent() {
app = new Vue({
el: "#app",
data: {
error: {
has_error: false,
error_message: ""
},
state: {
function defaultRemotePanelState() {
return {
writable: false,
mode: "unknown",
current_limit: null,
standby: null,
last_command: "",
last_error: "",
last_updated: ""
};
}
function defaultState() {
return {
output_current: null,
output_voltage: 0,
output_frequency: 0,
@@ -24,20 +29,194 @@ function loadContent() {
battery_voltage: 0,
battery_charge: 0,
battery_power: 0,
led_map: [
{ led_mains: "dot-off" },
{ led_absorb: "dot-off" },
{ led_bulk: "dot-off" },
{ led_float: "dot-off" },
{ led_inverter: "dot-off" },
{ led_overload: "dot-off" },
{ led_bat_low: "dot-off" },
{ led_over_temp: "dot-off" }
]
led_map: {
led_mains: "dot-off",
led_absorb: "dot-off",
led_bulk: "dot-off",
led_float: "dot-off",
led_inverter: "dot-off",
led_overload: "dot-off",
led_bat_low: "dot-off",
led_over_temp: "dot-off"
},
remote_panel: defaultRemotePanelState()
};
}
function loadContent() {
app = new Vue({
el: "#app",
data: {
error: {
has_error: false,
error_message: ""
},
control: {
busy: false,
has_error: false,
message: ""
},
remote_form: {
mode: "on",
current_limit: "",
standby: false
},
state: defaultState()
},
methods: {
syncRemoteFormFromState: function(remoteState) {
if (!remoteState) {
return;
}
if (remoteState.mode && remoteState.mode !== "unknown") {
this.remote_form.mode = remoteState.mode;
}
if (remoteState.current_limit === null || remoteState.current_limit === undefined) {
this.remote_form.current_limit = "";
} else {
this.remote_form.current_limit = String(remoteState.current_limit);
}
if (remoteState.standby === null || remoteState.standby === undefined) {
this.remote_form.standby = false;
} else {
this.remote_form.standby = !!remoteState.standby;
}
},
remoteModeLabel: function(remoteState) {
var mode = (remoteState && remoteState.mode) || "unknown";
if (mode === "charger_only") {
return "Charger Only";
}
if (mode === "inverter_only") {
return "Inverter Only";
}
if (mode === "on") {
return "On";
}
if (mode === "off") {
return "Off";
}
return "Unknown";
},
remoteStandbyLabel: function(remoteState) {
if (!remoteState || remoteState.standby === null || remoteState.standby === undefined) {
return "Unknown";
}
return remoteState.standby ? "Enabled" : "Disabled";
},
refreshRemoteState: function() {
var self = this;
fetch(getAPIURI("api/remote-panel/state"))
.then(function(resp) {
if (!resp.ok) {
throw new Error("Could not load remote panel state.");
}
return resp.json();
})
.then(function(payload) {
self.state.remote_panel = payload;
self.syncRemoteFormFromState(payload);
})
.catch(function(err) {
self.control.has_error = true;
self.control.message = err.message;
});
},
applyRemotePanelState: function() {
var self = this;
if (!self.state.remote_panel.writable) {
return;
}
var body = {
mode: self.remote_form.mode
};
if (self.remote_form.current_limit !== "") {
var parsed = parseFloat(self.remote_form.current_limit);
if (isNaN(parsed)) {
self.control.has_error = true;
self.control.message = "Current limit must be numeric.";
return;
}
body.current_limit = parsed;
}
self.control.busy = true;
self.control.has_error = false;
self.control.message = "";
fetch(getAPIURI("api/remote-panel/state"), {
method: "POST",
headers: {
"Content-Type": "application/json"
},
body: JSON.stringify(body)
})
.then(function(resp) {
if (!resp.ok) {
return resp.text().then(function(text) {
throw new Error(text || "Failed to set remote panel mode/current limit.");
});
}
return resp.json();
})
.then(function(payload) {
self.state.remote_panel = payload;
self.syncRemoteFormFromState(payload);
self.control.has_error = false;
self.control.message = "Remote panel state updated.";
})
.catch(function(err) {
self.control.has_error = true;
self.control.message = err.message;
})
.finally(function() {
self.control.busy = false;
});
},
applyStandby: function() {
var self = this;
if (!self.state.remote_panel.writable) {
return;
}
self.control.busy = true;
self.control.has_error = false;
self.control.message = "";
fetch(getAPIURI("api/remote-panel/standby"), {
method: "POST",
headers: {
"Content-Type": "application/json"
},
body: JSON.stringify({
standby: !!self.remote_form.standby
})
})
.then(function(resp) {
if (!resp.ok) {
return resp.text().then(function(text) {
throw new Error(text || "Failed to set standby mode.");
});
}
return resp.json();
})
.then(function(payload) {
self.state.remote_panel = payload;
self.syncRemoteFormFromState(payload);
self.control.has_error = false;
self.control.message = "Standby mode updated.";
})
.catch(function(err) {
self.control.has_error = true;
self.control.message = err.message;
})
.finally(function() {
self.control.busy = false;
});
}
}
});
app.refreshRemoteState();
connect();
}
@@ -61,7 +240,7 @@ function connect() {
}
};
conn.onopen = function(evt) {
conn.onopen = function() {
timeout = timeoutMin;
app.error.has_error = false;
};
@@ -69,6 +248,9 @@ function connect() {
conn.onmessage = function(evt) {
var update = JSON.parse(evt.data);
app.state = update;
if (!app.control.busy) {
app.syncRemoteFormFromState(update.remote_panel);
}
};
} else {
app.error.has_error = true;
@@ -88,3 +270,11 @@ function getURI() {
new_uri += loc.pathname + "ws";
return new_uri;
}
function getAPIURI(path) {
var base = window.location.pathname;
if (base.slice(-1) !== "/") {
base += "/";
}
return base + path.replace(/^\/+/, "");
}

View File

@@ -31,13 +31,15 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package webui
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"sync"
"time"
"github.com/diebietse/invertergui/mk2driver"
"github.com/diebietse/invertergui/websocket"
"invertergui/mk2driver"
"invertergui/websocket"
"github.com/sirupsen/logrus"
)
@@ -51,25 +53,53 @@ const (
BlinkGreen = "blink-green"
)
const (
modeChargerOnly = "charger_only"
modeInverterOnly = "inverter_only"
modeOn = "on"
modeOff = "off"
modeUnknown = "unknown"
)
type WebGui struct {
mk2driver.Mk2
writer mk2driver.SettingsWriter
stopChan chan struct{}
wg sync.WaitGroup
hub *websocket.Hub
stateMu sync.RWMutex
latest *templateInput
remote remotePanelState
}
func NewWebGui(source mk2driver.Mk2) *WebGui {
func NewWebGui(source mk2driver.Mk2, writer mk2driver.SettingsWriter) *WebGui {
w := &WebGui{
stopChan: make(chan struct{}),
Mk2: source,
writer: writer,
hub: websocket.NewHub(),
remote: remotePanelState{
Writable: writer != nil,
Mode: modeUnknown,
},
}
w.wg.Add(1)
go w.dataPoll()
return w
}
type remotePanelState struct {
Writable bool `json:"writable"`
Mode string `json:"mode"`
CurrentLimit *float64 `json:"current_limit,omitempty"`
Standby *bool `json:"standby,omitempty"`
LastCommand string `json:"last_command,omitempty"`
LastError string `json:"last_error,omitempty"`
LastUpdated string `json:"last_updated,omitempty"`
}
type templateInput struct {
Error []error `json:"errors"`
@@ -94,12 +124,125 @@ type templateInput struct {
OutFreq string `json:"output_frequency"`
LedMap map[string]string `json:"led_map"`
RemotePanel remotePanelState `json:"remote_panel"`
}
type setRemotePanelStateRequest struct {
Mode string `json:"mode"`
CurrentLimit *float64 `json:"current_limit"`
}
type setRemotePanelStandbyRequest struct {
Standby bool `json:"standby"`
}
func (w *WebGui) ServeHub(rw http.ResponseWriter, r *http.Request) {
w.hub.ServeHTTP(rw, r)
}
func (w *WebGui) ServeRemotePanelState(rw http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
writeJSON(rw, http.StatusOK, w.getRemotePanelState())
case http.MethodPost:
w.handleSetRemotePanelState(rw, r)
default:
http.Error(rw, "method not allowed", http.StatusMethodNotAllowed)
}
}
func (w *WebGui) ServeRemotePanelStandby(rw http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
writeJSON(rw, http.StatusOK, w.getRemotePanelState())
case http.MethodPost:
w.handleSetRemotePanelStandby(rw, r)
default:
http.Error(rw, "method not allowed", http.StatusMethodNotAllowed)
}
}
func (w *WebGui) handleSetRemotePanelState(rw http.ResponseWriter, r *http.Request) {
if w.writer == nil {
http.Error(rw, "remote control is not supported by this data source", http.StatusNotImplemented)
return
}
req := setRemotePanelStateRequest{}
if err := decodeJSONBody(r, &req); err != nil {
http.Error(rw, err.Error(), http.StatusBadRequest)
return
}
switchState, normalizedMode, err := parsePanelMode(req.Mode)
if err != nil {
http.Error(rw, err.Error(), http.StatusBadRequest)
return
}
if err := w.writer.SetPanelState(switchState, req.CurrentLimit); err != nil {
w.updateRemotePanelState(func(state *remotePanelState) {
state.LastCommand = "set_remote_panel_state"
state.LastError = err.Error()
})
http.Error(rw, err.Error(), http.StatusBadGateway)
return
}
w.updateRemotePanelState(func(state *remotePanelState) {
state.Mode = normalizedMode
state.CurrentLimit = copyFloat64Ptr(req.CurrentLimit)
state.LastCommand = "set_remote_panel_state"
state.LastError = ""
})
writeJSON(rw, http.StatusOK, w.getRemotePanelState())
}
func (w *WebGui) handleSetRemotePanelStandby(rw http.ResponseWriter, r *http.Request) {
if w.writer == nil {
http.Error(rw, "remote control is not supported by this data source", http.StatusNotImplemented)
return
}
req := setRemotePanelStandbyRequest{}
if err := decodeJSONBody(r, &req); err != nil {
http.Error(rw, err.Error(), http.StatusBadRequest)
return
}
if err := w.writer.SetStandby(req.Standby); err != nil {
w.updateRemotePanelState(func(state *remotePanelState) {
state.LastCommand = "set_remote_panel_standby"
state.LastError = err.Error()
})
http.Error(rw, err.Error(), http.StatusBadGateway)
return
}
w.updateRemotePanelState(func(state *remotePanelState) {
state.Standby = copyBoolPtr(&req.Standby)
state.LastCommand = "set_remote_panel_standby"
state.LastError = ""
})
writeJSON(rw, http.StatusOK, w.getRemotePanelState())
}
func parsePanelMode(raw string) (mk2driver.PanelSwitchState, string, error) {
switch strings.TrimSpace(strings.ToLower(raw)) {
case modeChargerOnly:
return mk2driver.PanelSwitchChargerOnly, modeChargerOnly, nil
case modeInverterOnly:
return mk2driver.PanelSwitchInverterOnly, modeInverterOnly, nil
case modeOn:
return mk2driver.PanelSwitchOn, modeOn, nil
case modeOff:
return mk2driver.PanelSwitchOff, modeOff, nil
default:
return 0, "", fmt.Errorf("unsupported panel mode %q", raw)
}
}
func ledName(led mk2driver.Led) string {
name, ok := mk2driver.LedNames[led]
if !ok {
@@ -162,15 +305,17 @@ func (w *WebGui) Stop() {
w.wg.Wait()
}
// dataPoll waits for data from the w.poller channel. It will send its currently stored status
// to respChan if anything reads from it.
func (w *WebGui) dataPoll() {
for {
select {
case s := <-w.C():
if s.Valid {
err := w.hub.Broadcast(buildTemplateInput(s))
if err != nil {
payload := buildTemplateInput(s)
w.stateMu.Lock()
payload.RemotePanel = w.remote
w.latest = payload
w.stateMu.Unlock()
if err := w.hub.Broadcast(payload); err != nil {
log.Errorf("Could not send update to clients: %v", err)
}
}
@@ -180,3 +325,93 @@ func (w *WebGui) dataPoll() {
}
}
}
func (w *WebGui) getRemotePanelState() remotePanelState {
w.stateMu.RLock()
defer w.stateMu.RUnlock()
return copyRemotePanelState(w.remote)
}
func (w *WebGui) updateRemotePanelState(update func(state *remotePanelState)) {
w.stateMu.Lock()
update(&w.remote)
w.remote.LastUpdated = time.Now().UTC().Format(time.RFC3339)
snapshot := w.snapshotLocked()
w.stateMu.Unlock()
if snapshot != nil {
if err := w.hub.Broadcast(snapshot); err != nil {
log.Errorf("Could not send control update to clients: %v", err)
}
}
}
func (w *WebGui) snapshotLocked() *templateInput {
if w.latest == nil {
return nil
}
snapshot := cloneTemplateInput(w.latest)
snapshot.RemotePanel = copyRemotePanelState(w.remote)
return snapshot
}
func cloneTemplateInput(in *templateInput) *templateInput {
if in == nil {
return nil
}
out := *in
if in.Error != nil {
out.Error = append([]error(nil), in.Error...)
}
if in.LedMap != nil {
out.LedMap = make(map[string]string, len(in.LedMap))
for k, v := range in.LedMap {
out.LedMap[k] = v
}
}
out.RemotePanel = copyRemotePanelState(in.RemotePanel)
return &out
}
func copyRemotePanelState(in remotePanelState) remotePanelState {
in.CurrentLimit = copyFloat64Ptr(in.CurrentLimit)
in.Standby = copyBoolPtr(in.Standby)
return in
}
func copyFloat64Ptr(in *float64) *float64 {
if in == nil {
return nil
}
value := *in
return &value
}
func copyBoolPtr(in *bool) *bool {
if in == nil {
return nil
}
value := *in
return &value
}
func decodeJSONBody(r *http.Request, destination any) error {
defer r.Body.Close()
decoder := json.NewDecoder(r.Body)
decoder.DisallowUnknownFields()
if err := decoder.Decode(destination); err != nil {
return fmt.Errorf("invalid request body: %w", err)
}
return nil
}
func writeJSON(rw http.ResponseWriter, statusCode int, payload any) {
rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(statusCode)
if err := json.NewEncoder(rw).Encode(payload); err != nil {
log.Errorf("Could not encode webui API response: %v", err)
}
}

View File

@@ -36,7 +36,7 @@ import (
"testing"
"time"
"github.com/diebietse/invertergui/mk2driver"
"invertergui/mk2driver"
)
type templateTest struct {
@@ -91,3 +91,53 @@ func TestTemplateInput(t *testing.T) {
}
}
}
func TestParsePanelMode(t *testing.T) {
tests := []struct {
name string
input string
want mk2driver.PanelSwitchState
wantRaw string
wantErr bool
}{
{
name: "on",
input: "on",
want: mk2driver.PanelSwitchOn,
wantRaw: "on",
},
{
name: "charger_only",
input: "charger_only",
want: mk2driver.PanelSwitchChargerOnly,
wantRaw: "charger_only",
},
{
name: "invalid",
input: "banana",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, gotRaw, err := parsePanelMode(tt.input)
if tt.wantErr {
if err == nil {
t.Fatal("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != tt.want {
t.Fatalf("got switch %d, want %d", got, tt.want)
}
if gotRaw != tt.wantRaw {
t.Fatalf("got mode %q, want %q", gotRaw, tt.wantRaw)
}
})
}
}

View File

@@ -70,3 +70,5 @@ benchstat <(go test -benchtime 500ms -count 15 -bench 'Sum64$')
- [VictoriaMetrics](https://github.com/VictoriaMetrics/VictoriaMetrics)
- [FreeCache](https://github.com/coocood/freecache)
- [FastCache](https://github.com/VictoriaMetrics/fastcache)
- [Ristretto](https://github.com/dgraph-io/ristretto)
- [Badger](https://github.com/dgraph-io/badger)

View File

@@ -19,10 +19,13 @@ const (
// Store the primes in an array as well.
//
// The consts are used when possible in Go code to avoid MOVs but we need a
// contiguous array of the assembly code.
// contiguous array for the assembly code.
var primes = [...]uint64{prime1, prime2, prime3, prime4, prime5}
// Digest implements hash.Hash64.
//
// Note that a zero-valued Digest is not ready to receive writes.
// Call Reset or create a Digest using New before calling other methods.
type Digest struct {
v1 uint64
v2 uint64
@@ -33,19 +36,31 @@ type Digest struct {
n int // how much of mem is used
}
// New creates a new Digest that computes the 64-bit xxHash algorithm.
// New creates a new Digest with a zero seed.
func New() *Digest {
return NewWithSeed(0)
}
// NewWithSeed creates a new Digest with the given seed.
func NewWithSeed(seed uint64) *Digest {
var d Digest
d.Reset()
d.ResetWithSeed(seed)
return &d
}
// Reset clears the Digest's state so that it can be reused.
// It uses a seed value of zero.
func (d *Digest) Reset() {
d.v1 = primes[0] + prime2
d.v2 = prime2
d.v3 = 0
d.v4 = -primes[0]
d.ResetWithSeed(0)
}
// ResetWithSeed clears the Digest's state so that it can be reused.
// It uses the given seed to initialize the state.
func (d *Digest) ResetWithSeed(seed uint64) {
d.v1 = seed + prime1 + prime2
d.v2 = seed + prime2
d.v3 = seed
d.v4 = seed - prime1
d.total = 0
d.n = 0
}

View File

@@ -6,7 +6,7 @@
package xxhash
// Sum64 computes the 64-bit xxHash digest of b.
// Sum64 computes the 64-bit xxHash digest of b with a zero seed.
//
//go:noescape
func Sum64(b []byte) uint64

View File

@@ -3,7 +3,7 @@
package xxhash
// Sum64 computes the 64-bit xxHash digest of b.
// Sum64 computes the 64-bit xxHash digest of b with a zero seed.
func Sum64(b []byte) uint64 {
// A simpler version would be
// d := New()

View File

@@ -5,7 +5,7 @@
package xxhash
// Sum64String computes the 64-bit xxHash digest of s.
// Sum64String computes the 64-bit xxHash digest of s with a zero seed.
func Sum64String(s string) uint64 {
return Sum64([]byte(s))
}

View File

@@ -33,7 +33,7 @@ import (
//
// See https://github.com/golang/go/issues/42739 for discussion.
// Sum64String computes the 64-bit xxHash digest of s.
// Sum64String computes the 64-bit xxHash digest of s with a zero seed.
// It may be faster than Sum64([]byte(s)) by avoiding a copy.
func Sum64String(s string) uint64 {
b := *(*[]byte)(unsafe.Pointer(&sliceHeader{s, len(s)}))

View File

@@ -258,12 +258,15 @@ func (c *client) Connect() Token {
return
}
var attemptCount int
RETRYCONN:
var conn net.Conn
var rc byte
var err error
conn, rc, t.sessionPresent, err = c.attemptConnection()
conn, rc, t.sessionPresent, err = c.attemptConnection(false, attemptCount)
if err != nil {
attemptCount++
if c.options.ConnectRetry {
DEBUG.Println(CLI, "Connect failed, sleeping for", int(c.options.ConnectRetryInterval.Seconds()), "seconds and will then retry, error:", err.Error())
time.Sleep(c.options.ConnectRetryInterval)
@@ -315,15 +318,17 @@ func (c *client) reconnect(connectionUp connCompletedFn) {
DEBUG.Println(CLI, "Detect continual connection lost after reconnect, slept for", int(slp.Seconds()), "seconds")
}
var attemptCount int
for {
if nil != c.options.OnReconnecting {
c.options.OnReconnecting(c, &c.options)
}
var err error
conn, _, _, err = c.attemptConnection()
conn, _, _, err = c.attemptConnection(true, attemptCount)
if err == nil {
break
}
attemptCount++
sleep, _ := c.backoff.sleepWithBackoff("attemptReconnection", initSleep, c.options.MaxReconnectInterval, c.options.ConnectTimeout, false)
DEBUG.Println(CLI, "Reconnect failed, slept for", int(sleep.Seconds()), "seconds:", err)
@@ -351,7 +356,7 @@ func (c *client) reconnect(connectionUp connCompletedFn) {
// byte - Return code (packets.Accepted indicates a successful connection).
// bool - SessionPresent flag from the connect ack (only valid if packets.Accepted)
// err - Error (err != nil guarantees that conn has been set to active connection).
func (c *client) attemptConnection() (net.Conn, byte, bool, error) {
func (c *client) attemptConnection(isReconnect bool, attempt int) (net.Conn, byte, bool, error) {
protocolVersion := c.options.ProtocolVersion
var (
sessionPresent bool
@@ -360,6 +365,10 @@ func (c *client) attemptConnection() (net.Conn, byte, bool, error) {
rc byte
)
if c.options.OnConnectionNotification != nil {
c.options.OnConnectionNotification(c, ConnectionNotificationConnecting{isReconnect, attempt})
}
c.optionsMu.Lock() // Protect c.options.Servers so that servers can be added in test cases
brokers := c.options.Servers
c.optionsMu.Unlock()
@@ -372,6 +381,9 @@ func (c *client) attemptConnection() (net.Conn, byte, bool, error) {
DEBUG.Println(CLI, "using custom onConnectAttempt handler...")
tlsCfg = c.options.OnConnectAttempt(broker, c.options.TLSConfig)
}
if c.options.OnConnectionNotification != nil {
c.options.OnConnectionNotification(c, ConnectionNotificationBroker{broker})
}
connDeadline := time.Now().Add(c.options.ConnectTimeout) // Time by which connection must be established
dialer := c.options.Dialer
if dialer == nil { //
@@ -388,6 +400,9 @@ func (c *client) attemptConnection() (net.Conn, byte, bool, error) {
ERROR.Println(CLI, err.Error())
WARN.Println(CLI, "failed to connect to broker, trying next")
rc = packets.ErrNetworkError
if c.options.OnConnectionNotification != nil {
c.options.OnConnectionNotification(c, ConnectionNotificationBrokerFailed{broker, err})
}
continue
}
DEBUG.Println(CLI, "socket connected to broker")
@@ -427,9 +442,12 @@ func (c *client) attemptConnection() (net.Conn, byte, bool, error) {
if rc != packets.ErrNetworkError { // mqtt error
err = packets.ConnErrors[rc]
} else { // network error (if this occurred in ConnectMQTT then err will be nil)
err = fmt.Errorf("%s : %s", packets.ConnErrors[rc], err)
err = fmt.Errorf("%w : %w", packets.ConnErrors[rc], err)
}
}
if err != nil && c.options.OnConnectionNotification != nil {
c.options.OnConnectionNotification(c, ConnectionNotificationFailed{err})
}
return conn, rc, sessionPresent, err
}
@@ -564,6 +582,9 @@ func (c *client) internalConnLost(whyConnLost error) {
if c.options.OnConnectionLost != nil {
go c.options.OnConnectionLost(c, whyConnLost)
}
if c.options.OnConnectionNotification != nil {
go c.options.OnConnectionNotification(c, ConnectionNotificationLost{whyConnLost})
}
DEBUG.Println(CLI, "internalConnLost complete")
}()
}
@@ -601,21 +622,21 @@ func (c *client) startCommsWorkers(conn net.Conn, connectionUp connCompletedFn,
c.workers.Add(1) // Done will be called when ackOut is closed
ackOut := c.msgRouter.matchAndDispatch(incomingPubChan, c.options.Order, c)
// The connection is now ready for use (we spin up a few go routines below). It is possible that
// Disconnect has been called in the interim...
// The connection is now ready for use (we spin up a few go routines below).
// It is possible that Disconnect has been called in the interim...
// issue 675we will allow the connection to complete before the Disconnect is allowed to proceed
// as if a Disconnect event occurred immediately after connectionUp(true) completed.
if err := connectionUp(true); err != nil {
DEBUG.Println(CLI, err)
close(c.stop) // Tidy up anything we have already started
close(incomingPubChan)
c.workers.Wait()
c.conn.Close()
c.conn = nil
return false
ERROR.Println(CLI, err)
}
DEBUG.Println(CLI, "client is connected/reconnected")
if c.options.OnConnect != nil {
go c.options.OnConnect(c)
}
if c.options.OnConnectionNotification != nil {
go c.options.OnConnectionNotification(c, ConnectionNotificationConnected{})
}
// c.oboundP and c.obound need to stay active for the life of the client because, depending upon the options,
// messages may be published while the client is disconnected (they will block unless in a goroutine). However
@@ -799,9 +820,13 @@ func (c *client) Publish(topic string, qos byte, retained bool, payload interfac
if publishWaitTimeout == 0 {
publishWaitTimeout = time.Second * 30
}
t := time.NewTimer(publishWaitTimeout)
defer t.Stop()
select {
case c.obound <- &PacketAndToken{p: pub, t: token}:
case <-time.After(publishWaitTimeout):
case <-t.C:
token.setError(errors.New("publish was broken by timeout"))
}
}

View File

@@ -19,7 +19,7 @@
package mqtt
import (
"io/ioutil"
"io/fs"
"os"
"path"
"sort"
@@ -159,15 +159,20 @@ func (store *FileStore) Reset() {
func (store *FileStore) all() []string {
var err error
var keys []string
var files fileInfos
if !store.opened {
ERROR.Println(STR, "trying to use file store, but not open")
return nil
}
files, err = ioutil.ReadDir(store.directory)
entries, err := os.ReadDir(store.directory)
chkerr(err)
files := make(fileInfos, 0, len(entries))
for _, entry := range entries {
info, err := entry.Info()
chkerr(err)
files = append(files, info)
}
sort.Sort(files)
for _, f := range files {
DEBUG.Println(STR, "file in All():", f.Name())
@@ -246,7 +251,7 @@ func exists(file string) bool {
return true
}
type fileInfos []os.FileInfo
type fileInfos []fs.FileInfo
func (f fileInfos) Len() int {
return len(f)

View File

@@ -444,24 +444,23 @@ func startComms(conn net.Conn, // Network connection (must be active)
}
// ackFunc acknowledges a packet
// WARNING the function returned must not be called if the comms routine is shutting down or not running
// (it needs outgoing comms in order to send the acknowledgement). Currently this is only called from
// matchAndDispatch which will be shutdown before the comms are
func ackFunc(oboundP chan *PacketAndToken, persist Store, packet *packets.PublishPacket) func() {
// WARNING sendAck may be called at any time (even after the connection is dead). At the time of writing ACK sent after
// connection loss will be dropped (this is not ideal)
func ackFunc(sendAck func(*PacketAndToken), persist Store, packet *packets.PublishPacket) func() {
return func() {
switch packet.Qos {
case 2:
pr := packets.NewControlPacket(packets.Pubrec).(*packets.PubrecPacket)
pr.MessageID = packet.MessageID
DEBUG.Println(NET, "putting pubrec msg on obound")
oboundP <- &PacketAndToken{p: pr, t: nil}
sendAck(&PacketAndToken{p: pr, t: nil})
DEBUG.Println(NET, "done putting pubrec msg on obound")
case 1:
pa := packets.NewControlPacket(packets.Puback).(*packets.PubackPacket)
pa.MessageID = packet.MessageID
DEBUG.Println(NET, "putting puback msg on obound")
persistOutbound(persist, pa)
oboundP <- &PacketAndToken{p: pa, t: nil}
persistOutbound(persist, pa) // May fail if store has been closed
sendAck(&PacketAndToken{p: pa, t: nil})
DEBUG.Println(NET, "done putting puback msg on obound")
case 0:
// do nothing, since there is no need to send an ack packet back

View File

@@ -50,16 +50,7 @@ func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, heade
conn, err := NewWebsocket(dialURI.String(), tlsc, timeout, headers, websocketOptions)
return conn, err
case "mqtt", "tcp":
allProxy := os.Getenv("all_proxy")
if len(allProxy) == 0 {
conn, err := dialer.Dial("tcp", uri.Host)
if err != nil {
return nil, err
}
return conn, nil
}
proxyDialer := proxy.FromEnvironment()
proxyDialer := proxy.FromEnvironmentUsing(dialer)
conn, err := proxyDialer.Dial("tcp", uri.Host)
if err != nil {
return nil, err

View File

@@ -62,6 +62,9 @@ type ConnectionAttemptHandler func(broker *url.URL, tlsCfg *tls.Config) *tls.Con
// Does not carry out any MQTT specific handshakes.
type OpenConnectionFunc func(uri *url.URL, options ClientOptions) (net.Conn, error)
// ConnectionNotificationHandler is invoked for any type of connection event.
type ConnectionNotificationHandler func(Client, ConnectionNotification)
// ClientOptions contains configurable options for an Client. Note that these should be set using the
// relevant methods (e.g. AddBroker) rather than directly. See those functions for information on usage.
// WARNING: Create the below using NewClientOptions unless you have a compelling reason not to. It is easy
@@ -96,6 +99,7 @@ type ClientOptions struct {
OnConnectionLost ConnectionLostHandler
OnReconnecting ReconnectHandler
OnConnectAttempt ConnectionAttemptHandler
OnConnectionNotification ConnectionNotificationHandler
WriteTimeout time.Duration
MessageChannelDepth uint
ResumeSubs bool
@@ -109,6 +113,7 @@ type ClientOptions struct {
// NewClientOptions will create a new ClientClientOptions type with some
// default values.
//
// Port: 1883
// CleanSession: True
// Order: True (note: it is recommended that this be set to FALSE unless order is important)
@@ -142,6 +147,7 @@ func NewClientOptions() *ClientOptions {
OnConnect: nil,
OnConnectionLost: DefaultConnectionLostHandler,
OnConnectAttempt: nil,
OnConnectionNotification: nil,
WriteTimeout: 0, // 0 represents timeout disabled
ResumeSubs: false,
HTTPHeaders: make(map[string][]string),
@@ -355,6 +361,13 @@ func (o *ClientOptions) SetConnectionAttemptHandler(onConnectAttempt ConnectionA
return o
}
// SetConnectionNotificationHandler sets the ConnectionNotificationHandler callback to receive all types of connection
// events.
func (o *ClientOptions) SetConnectionNotificationHandler(onConnectionNotification ConnectionNotificationHandler) *ClientOptions {
o.OnConnectionNotification = onConnectionNotification
return o
}
// SetWriteTimeout puts a limit on how long a mqtt publish should block until it unblocks with a
// timeout error. A duration of 0 never times out. Default never times out
func (o *ClientOptions) SetWriteTimeout(t time.Duration) *ClientOptions {
@@ -450,6 +463,7 @@ func (o *ClientOptions) SetCustomOpenConnectionFn(customOpenConnectionFn OpenCon
}
// SetAutoAckDisabled enables or disables the Automated Acking of Messages received by the handler.
//
// By default it is set to false. Setting it to true will disable the auto-ack globally.
func (o *ClientOptions) SetAutoAckDisabled(autoAckDisabled bool) *ClientOptions {
o.AutoAckDisabled = autoAckDisabled

View File

@@ -30,6 +30,21 @@ type ClientOptionsReader struct {
options *ClientOptions
}
// NewOptionsReader creates a ClientOptionsReader, this should only be used for mocking purposes.
//
// An example implementation:
//
// func (c *mqttClientMock) OptionsReader() mqtt.ClientOptionsReader {
// opts := mqtt.NewClientOptions()
// opts.UserName = "TestUserName"
// return mqtt.NewOptionsReader(opts)
// }
func NewOptionsReader(o *ClientOptions) ClientOptionsReader {
return ClientOptionsReader{
options: o,
}
}
// Servers returns a slice of the servers defined in the clientoptions
func (r *ClientOptionsReader) Servers() []*url.URL {
s := make([]*url.URL, len(r.options.Servers))

View File

@@ -330,6 +330,11 @@ func decodeBytes(b io.Reader) ([]byte, error) {
}
func encodeBytes(field []byte) []byte {
// Attempting to encode more than 65,535 bytes would lead to an unexpected 16-bit length and extra data written
// (which would be parsed as later parts of the message). The safest option is to truncate.
if len(field) > 65535 {
field = field[0:65535]
}
fieldLength := make([]byte, 2)
binary.BigEndian.PutUint16(fieldLength, uint16(len(field)))
return append(fieldLength, field...)

View File

@@ -38,7 +38,7 @@ func keepalive(c *client, conn io.Writer) {
if c.options.KeepAlive > 10 {
checkInterval = 5 * time.Second
} else {
checkInterval = time.Duration(c.options.KeepAlive) * time.Second / 2
checkInterval = time.Duration(c.options.KeepAlive) * time.Second / 4
}
intervalTicker := time.NewTicker(checkInterval)

View File

@@ -136,60 +136,41 @@ func (r *router) setDefaultHandler(handler MessageHandler) {
// associated callback (or the defaultHandler, if one exists and no other route matched). If
// anything is sent down the stop channel the function will end.
func (r *router) matchAndDispatch(messages <-chan *packets.PublishPacket, order bool, client *client) <-chan *PacketAndToken {
var wg sync.WaitGroup
ackOutChan := make(chan *PacketAndToken) // Channel returned to caller; closed when messages channel closed
var ackInChan chan *PacketAndToken // ACKs generated by ackFunc get put onto this channel
ackChan := make(chan *PacketAndToken) // Channel returned to caller; closed when goroutine terminates
stopAckCopy := make(chan struct{}) // Closure requests stop of go routine copying ackInChan to ackOutChan
ackCopyStopped := make(chan struct{}) // Closure indicates that it is safe to close ackOutChan
goRoutinesDone := make(chan struct{}) // closed on wg.Done()
if order {
ackInChan = ackOutChan // When order = true no go routines are used so safe to use one channel and close when done
// In some cases message acknowledgments may come through after shutdown (connection is down etc). Where this is the
// case we need to accept any such requests and then ignore them. Note that this is not a perfect solution, if we
// have reconnected, and the session is still live, then the Ack really should be sent (see Issus #726)
var ackMutex sync.RWMutex
sendAckChan := ackChan // This will be set to nil before ackChan is closed
sendAck := func(ack *PacketAndToken) {
ackMutex.RLock()
defer ackMutex.RUnlock()
if sendAckChan != nil {
sendAckChan <- ack
} else {
// When order = false ACK messages are sent in go routines so ackInChan cannot be closed until all goroutines done
ackInChan = make(chan *PacketAndToken)
go func() { // go routine to copy from ackInChan to ackOutChan until stopped
for {
select {
case a := <-ackInChan:
ackOutChan <- a
case <-stopAckCopy:
close(ackCopyStopped) // Signal main go routine that it is safe to close ackOutChan
for {
select {
case <-ackInChan: // drain ackInChan to ensure all goRoutines can complete cleanly (ACK dropped)
DEBUG.Println(ROU, "matchAndDispatch received acknowledgment after processing stopped (ACK dropped).")
case <-goRoutinesDone:
close(ackInChan) // Nothing further should be sent (a panic is probably better than silent failure)
DEBUG.Println(ROU, "matchAndDispatch order=false copy goroutine exiting.")
return
}
}
}
}
}()
}
go func() { // Main go routine handling inbound messages
var handlers []MessageHandler
for message := range messages {
// DEBUG.Println(ROU, "matchAndDispatch received message")
sent := false
r.RLock()
m := messageFromPublish(message, ackFunc(ackInChan, client.persist, message))
var handlers []MessageHandler
m := messageFromPublish(message, ackFunc(sendAck, client.persist, message))
for e := r.routes.Front(); e != nil; e = e.Next() {
if e.Value.(*route).match(message.TopicName) {
if order {
handlers = append(handlers, e.Value.(*route).callback)
} else {
hd := e.Value.(*route).callback
wg.Add(1)
go func() {
hd(client, m)
if !client.options.AutoAckDisabled {
m.Ack()
}
wg.Done()
}()
}
sent = true
@@ -200,13 +181,11 @@ func (r *router) matchAndDispatch(messages <-chan *packets.PublishPacket, order
if order {
handlers = append(handlers, r.defaultHandler)
} else {
wg.Add(1)
go func() {
r.defaultHandler(client, m)
if !client.options.AutoAckDisabled {
m.Ack()
}
wg.Done()
}()
}
} else {
@@ -214,26 +193,22 @@ func (r *router) matchAndDispatch(messages <-chan *packets.PublishPacket, order
}
}
r.RUnlock()
if order {
for _, handler := range handlers {
handler(client, m)
if !client.options.AutoAckDisabled {
m.Ack()
}
}
handlers = handlers[:0]
}
// DEBUG.Println(ROU, "matchAndDispatch handled message")
}
if order {
close(ackOutChan)
} else { // Ensure that nothing further will be written to ackOutChan before closing it
close(stopAckCopy)
<-ackCopyStopped
close(ackOutChan)
go func() {
wg.Wait() // Note: If this remains running then the user has handlers that are not returning
close(goRoutinesDone)
}()
}
ackMutex.Lock()
sendAckChan = nil
ackMutex.Unlock()
close(ackChan) // as sendAckChan is now nil nothing further will be sent on this
DEBUG.Println(ROU, "matchAndDispatch exiting")
}()
return ackOutChan
return ackChan
}

View File

@@ -17,6 +17,7 @@
package mqtt
import (
"errors"
"sync"
"time"
@@ -202,3 +203,20 @@ type UnsubscribeToken struct {
type DisconnectToken struct {
baseToken
}
// TimedOut is the error returned by WaitTimeout when the timeout expires
var TimedOut = errors.New("context canceled")
// WaitTokenTimeout is a utility function used to simplify the use of token.WaitTimeout
// token.WaitTimeout may return `false` due to time out but t.Error() still results
// in nil.
// `if t := client.X(); t.WaitTimeout(time.Second) && t.Error() != nil {` may evaluate
// to false even if the operation fails.
// It is important to note that if TimedOut is returned, then the operation may still be running
// and could eventually complete successfully.
func WaitTokenTimeout(t Token, d time.Duration) error {
if !t.WaitTimeout(d) {
return TimedOut
}
return t.Error()
}

View File

@@ -1,20 +0,0 @@
; https://editorconfig.org/
root = true
[*]
insert_final_newline = true
charset = utf-8
trim_trailing_whitespace = true
indent_style = space
indent_size = 2
[{Makefile,go.mod,go.sum,*.go,.gitmodules}]
indent_style = tab
indent_size = 4
[*.md]
indent_size = 4
trim_trailing_whitespace = false
eclint_indent_style = unset

View File

@@ -1 +1,25 @@
coverage.coverprofile
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
.idea/
*.iml

View File

@@ -1,3 +0,0 @@
run:
skip-dirs:
- examples/*.go

View File

@@ -1,27 +1,22 @@
Copyright (c) 2023 The Gorilla Authors. All rights reserved.
Copyright (c) 2013 The Gorilla WebSocket Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -1,34 +0,0 @@
GO_LINT=$(shell which golangci-lint 2> /dev/null || echo '')
GO_LINT_URI=github.com/golangci/golangci-lint/cmd/golangci-lint@latest
GO_SEC=$(shell which gosec 2> /dev/null || echo '')
GO_SEC_URI=github.com/securego/gosec/v2/cmd/gosec@latest
GO_VULNCHECK=$(shell which govulncheck 2> /dev/null || echo '')
GO_VULNCHECK_URI=golang.org/x/vuln/cmd/govulncheck@latest
.PHONY: golangci-lint
golangci-lint:
$(if $(GO_LINT), ,go install $(GO_LINT_URI))
@echo "##### Running golangci-lint"
golangci-lint run -v
.PHONY: gosec
gosec:
$(if $(GO_SEC), ,go install $(GO_SEC_URI))
@echo "##### Running gosec"
gosec -exclude-dir examples ./...
.PHONY: govulncheck
govulncheck:
$(if $(GO_VULNCHECK), ,go install $(GO_VULNCHECK_URI))
@echo "##### Running govulncheck"
govulncheck ./...
.PHONY: verify
verify: golangci-lint gosec govulncheck
.PHONY: test
test:
@echo "##### Running tests"
go test -race -cover -coverprofile=coverage.coverprofile -covermode=atomic -v ./...

View File

@@ -1,13 +1,10 @@
# gorilla/websocket
# Gorilla WebSocket
![testing](https://github.com/gorilla/websocket/actions/workflows/test.yml/badge.svg)
[![codecov](https://codecov.io/github/gorilla/websocket/branch/main/graph/badge.svg)](https://codecov.io/github/gorilla/websocket)
[![godoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket)
[![sourcegraph](https://sourcegraph.com/github.com/gorilla/websocket/-/badge.svg)](https://sourcegraph.com/github.com/gorilla/websocket?badge)
[![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket)
[![CircleCI](https://circleci.com/gh/gorilla/websocket.svg?style=svg)](https://circleci.com/gh/gorilla/websocket)
Gorilla WebSocket is a [Go](http://golang.org/) implementation of the [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
![Gorilla Logo](https://github.com/gorilla/.github/assets/53367916/d92caabf-98e0-473e-bfbf-ab554ba435e5)
Gorilla WebSocket is a [Go](http://golang.org/) implementation of the
[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
### Documentation
@@ -17,7 +14,6 @@ Gorilla WebSocket is a [Go](http://golang.org/) implementation of the [WebSocket
* [Command example](https://github.com/gorilla/websocket/tree/master/examples/command)
* [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo)
* [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch)
* [Write buffer pool example](https://github.com/gorilla/websocket/tree/master/examples/bufferpool)
### Status
@@ -34,3 +30,4 @@ package API is stable.
The Gorilla WebSocket package passes the server tests in the [Autobahn Test
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).

View File

@@ -11,16 +11,13 @@ import (
"errors"
"fmt"
"io"
"log"
"io/ioutil"
"net"
"net/http"
"net/http/httptrace"
"net/url"
"strings"
"time"
"golang.org/x/net/proxy"
)
// ErrBadHandshake is returned when the server response to opening handshake is
@@ -228,7 +225,6 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
k == "Connection" ||
k == "Sec-Websocket-Key" ||
k == "Sec-Websocket-Version" ||
//#nosec G101 (CWE-798): Potential HTTP request smuggling via parameter pollution
k == "Sec-Websocket-Extensions" ||
(k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
@@ -294,9 +290,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}
err = c.SetDeadline(deadline)
if err != nil {
if err := c.Close(); err != nil {
log.Printf("websocket: failed to close network connection: %v", err)
}
c.Close()
return nil, err
}
return c, nil
@@ -310,7 +304,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
return nil, nil, err
}
if proxyURL != nil {
dialer, err := proxy.FromURL(proxyURL, netDialerFunc(netDial))
dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial))
if err != nil {
return nil, nil, err
}
@@ -336,9 +330,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
defer func() {
if netConn != nil {
if err := netConn.Close(); err != nil {
log.Printf("websocket: failed to close network connection: %v", err)
}
netConn.Close()
}
}()
@@ -408,7 +400,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
// debugging.
buf := make([]byte, 1024)
n, _ := io.ReadFull(resp.Body, buf)
resp.Body = io.NopCloser(bytes.NewReader(buf[:n]))
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
return nil, resp, ErrBadHandshake
}
@@ -426,19 +418,17 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
break
}
resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
if err := netConn.SetDeadline(time.Time{}); err != nil {
return nil, nil, err
}
netConn.SetDeadline(time.Time{})
netConn = nil // to avoid close in defer.
return conn, resp, nil
}
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{MinVersion: tls.VersionTLS12}
return &tls.Config{}
}
return cfg.Clone()
}

View File

@@ -8,7 +8,6 @@ import (
"compress/flate"
"errors"
"io"
"log"
"strings"
"sync"
)
@@ -34,9 +33,7 @@ func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
"\x01\x00\x00\xff\xff"
fr, _ := flateReaderPool.Get().(io.ReadCloser)
if err := fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil); err != nil {
panic(err)
}
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
return &flateReadWrapper{fr}
}
@@ -135,9 +132,7 @@ func (r *flateReadWrapper) Read(p []byte) (int, error) {
// Preemptively place the reader back in the pool. This helps with
// scenarios where the application does not call NextReader() soon after
// this final read.
if err := r.Close(); err != nil {
log.Printf("websocket: flateReadWrapper.Close() returned error: %v", err)
}
r.Close()
}
return n, err
}

View File

@@ -6,11 +6,11 @@ package websocket
import (
"bufio"
"crypto/rand"
"encoding/binary"
"errors"
"io"
"log"
"io/ioutil"
"math/rand"
"net"
"strconv"
"strings"
@@ -181,20 +181,13 @@ var (
errInvalidControlFrame = errors.New("websocket: invalid control frame")
)
// maskRand is an io.Reader for generating mask bytes. The reader is initialized
// to crypto/rand Reader. Tests swap the reader to a math/rand reader for
// reproducible results.
var maskRand = rand.Reader
// newMaskKey returns a new 32 bit value for masking client frames.
func newMaskKey() [4]byte {
var k [4]byte
_, _ = io.ReadFull(maskRand, k[:])
return k
n := rand.Uint32()
return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}
}
func hideTempErr(err error) error {
if e, ok := err.(net.Error); ok {
if e, ok := err.(net.Error); ok && e.Temporary() {
err = &netError{msg: e.Error(), timeout: e.Timeout()}
}
return err
@@ -379,9 +372,7 @@ func (c *Conn) read(n int) ([]byte, error) {
if err == io.EOF {
err = errUnexpectedEOF
}
if _, err := c.br.Discard(len(p)); err != nil {
return p, err
}
c.br.Discard(len(p))
return p, err
}
@@ -396,9 +387,7 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
return err
}
if err := c.conn.SetWriteDeadline(deadline); err != nil {
return c.writeFatal(err)
}
c.conn.SetWriteDeadline(deadline)
if len(buf1) == 0 {
_, err = c.conn.Write(buf0)
} else {
@@ -408,7 +397,7 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
return c.writeFatal(err)
}
if frameType == CloseMessage {
_ = c.writeFatal(ErrCloseSent)
c.writeFatal(ErrCloseSent)
}
return nil
}
@@ -449,7 +438,7 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
d := 1000 * time.Hour
if !deadline.IsZero() {
d = time.Until(deadline)
d = deadline.Sub(time.Now())
if d < 0 {
return errWriteTimeout
}
@@ -471,15 +460,13 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
return err
}
if err := c.conn.SetWriteDeadline(deadline); err != nil {
return c.writeFatal(err)
}
c.conn.SetWriteDeadline(deadline)
_, err = c.conn.Write(buf)
if err != nil {
return c.writeFatal(err)
}
if messageType == CloseMessage {
_ = c.writeFatal(ErrCloseSent)
c.writeFatal(ErrCloseSent)
}
return err
}
@@ -490,9 +477,7 @@ func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
// probably better to return an error in this situation, but we cannot
// change this without breaking existing applications.
if c.writer != nil {
if err := c.writer.Close(); err != nil {
log.Printf("websocket: discarding writer close error: %v", err)
}
c.writer.Close()
c.writer = nil
}
@@ -645,7 +630,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
}
if final {
_ = w.endMessage(errWriteClosed)
w.endMessage(errWriteClosed)
return nil
}
@@ -810,7 +795,7 @@ func (c *Conn) advanceFrame() (int, error) {
// 1. Skip remainder of previous frame.
if c.readRemaining > 0 {
if _, err := io.CopyN(io.Discard, c.br, c.readRemaining); err != nil {
if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil {
return noFrame, err
}
}
@@ -832,9 +817,7 @@ func (c *Conn) advanceFrame() (int, error) {
rsv2 := p[0]&rsv2Bit != 0
rsv3 := p[0]&rsv3Bit != 0
mask := p[1]&maskBit != 0
if err := c.setReadRemaining(int64(p[1] & 0x7f)); err != nil {
return noFrame, err
}
c.setReadRemaining(int64(p[1] & 0x7f))
c.readDecompress = false
if rsv1 {
@@ -939,9 +922,7 @@ func (c *Conn) advanceFrame() (int, error) {
}
if c.readLimit > 0 && c.readLength > c.readLimit {
if err := c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)); err != nil {
return noFrame, err
}
c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
return noFrame, ErrReadLimit
}
@@ -953,9 +934,7 @@ func (c *Conn) advanceFrame() (int, error) {
var payload []byte
if c.readRemaining > 0 {
payload, err = c.read(int(c.readRemaining))
if err := c.setReadRemaining(0); err != nil {
return noFrame, err
}
c.setReadRemaining(0)
if err != nil {
return noFrame, err
}
@@ -1002,9 +981,7 @@ func (c *Conn) handleProtocolError(message string) error {
if len(data) > maxControlFramePayloadSize {
data = data[:maxControlFramePayloadSize]
}
if err := c.WriteControl(CloseMessage, data, time.Now().Add(writeWait)); err != nil {
return err
}
c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
return errors.New("websocket: " + message)
}
@@ -1021,9 +998,7 @@ func (c *Conn) handleProtocolError(message string) error {
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
// Close previous reader, only relevant for decompression.
if c.reader != nil {
if err := c.reader.Close(); err != nil {
log.Printf("websocket: discarding reader close error: %v", err)
}
c.reader.Close()
c.reader = nil
}
@@ -1079,9 +1054,7 @@ func (r *messageReader) Read(b []byte) (int, error) {
}
rem := c.readRemaining
rem -= int64(n)
if err := c.setReadRemaining(rem); err != nil {
return 0, err
}
c.setReadRemaining(rem)
if c.readRemaining > 0 && c.readErr == io.EOF {
c.readErr = errUnexpectedEOF
}
@@ -1121,7 +1094,7 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
if err != nil {
return messageType, nil, err
}
p, err = io.ReadAll(r)
p, err = ioutil.ReadAll(r)
return messageType, p, err
}
@@ -1163,9 +1136,7 @@ func (c *Conn) SetCloseHandler(h func(code int, text string) error) {
if h == nil {
h = func(code int, text string) error {
message := FormatCloseMessage(code, "")
if err := c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)); err != nil {
return err
}
c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
return nil
}
}
@@ -1190,7 +1161,7 @@ func (c *Conn) SetPingHandler(h func(appData string) error) {
err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait))
if err == ErrCloseSent {
return nil
} else if _, ok := err.(net.Error); ok {
} else if e, ok := err.(net.Error); ok && e.Temporary() {
return nil
}
return err

View File

@@ -9,7 +9,6 @@ package websocket
import "unsafe"
// #nosec G103 -- (CWE-242) Has been audited
const wordSize = int(unsafe.Sizeof(uintptr(0)))
func maskBytes(key [4]byte, pos int, b []byte) int {
@@ -23,7 +22,6 @@ func maskBytes(key [4]byte, pos int, b []byte) int {
}
// Mask one byte at a time to word boundary.
//#nosec G103 -- (CWE-242) Has been audited
if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 {
n = wordSize - n
for i := range b[:n] {
@@ -38,13 +36,11 @@ func maskBytes(key [4]byte, pos int, b []byte) int {
for i := range k {
k[i] = key[(pos+i)&3]
}
//#nosec G103 -- (CWE-242) Has been audited
kw := *(*uintptr)(unsafe.Pointer(&k))
// Mask one word at a time.
n := (len(b) / wordSize) * wordSize
for i := 0; i < n; i += wordSize {
//#nosec G103 -- (CWE-242) Has been audited
*(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw
}

View File

@@ -8,13 +8,10 @@ import (
"bufio"
"encoding/base64"
"errors"
"log"
"net"
"net/http"
"net/url"
"strings"
"golang.org/x/net/proxy"
)
type netDialerFunc func(network, addr string) (net.Conn, error)
@@ -24,7 +21,7 @@ func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
}
func init() {
proxy.RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy.Dialer) (proxy.Dialer, error) {
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
})
}
@@ -58,9 +55,7 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
}
if err := connectReq.Write(conn); err != nil {
if err := conn.Close(); err != nil {
log.Printf("httpProxyDialer: failed to close connection: %v", err)
}
conn.Close()
return nil, err
}
@@ -69,16 +64,12 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
br := bufio.NewReader(conn)
resp, err := http.ReadResponse(br, connectReq)
if err != nil {
if err := conn.Close(); err != nil {
log.Printf("httpProxyDialer: failed to close connection: %v", err)
}
conn.Close()
return nil, err
}
if resp.StatusCode != 200 {
if err := conn.Close(); err != nil {
log.Printf("httpProxyDialer: failed to close connection: %v", err)
}
conn.Close()
f := strings.SplitN(resp.Status, " ", 2)
return nil, errors.New(f[1])
}

View File

@@ -8,7 +8,6 @@ import (
"bufio"
"errors"
"io"
"log"
"net/http"
"net/url"
"strings"
@@ -184,9 +183,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
}
if brw.Reader.Buffered() > 0 {
if err := netConn.Close(); err != nil {
log.Printf("websocket: failed to close network connection: %v", err)
}
netConn.Close()
return nil, errors.New("websocket: client sent data before handshake is complete")
}
@@ -251,34 +248,17 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
p = append(p, "\r\n"...)
// Clear deadlines set by HTTP server.
if err := netConn.SetDeadline(time.Time{}); err != nil {
if err := netConn.Close(); err != nil {
log.Printf("websocket: failed to close network connection: %v", err)
}
return nil, err
}
netConn.SetDeadline(time.Time{})
if u.HandshakeTimeout > 0 {
if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil {
if err := netConn.Close(); err != nil {
log.Printf("websocket: failed to close network connection: %v", err)
}
return nil, err
}
netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
}
if _, err = netConn.Write(p); err != nil {
if err := netConn.Close(); err != nil {
log.Printf("websocket: failed to close network connection: %v", err)
}
netConn.Close()
return nil, err
}
if u.HandshakeTimeout > 0 {
if err := netConn.SetWriteDeadline(time.Time{}); err != nil {
if err := netConn.Close(); err != nil {
log.Printf("websocket: failed to close network connection: %v", err)
}
return nil, err
}
netConn.SetWriteDeadline(time.Time{})
}
return c, nil
@@ -376,12 +356,8 @@ func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte {
// bufio.Writer's underlying writer.
var wh writeHook
bw.Reset(&wh)
if err := bw.WriteByte(0); err != nil {
panic(err)
}
if err := bw.Flush(); err != nil {
log.Printf("websocket: bufioWriterBuffer: Flush: %v", err)
}
bw.WriteByte(0)
bw.Flush()
bw.Reset(originalWriter)

View File

@@ -1,3 +1,6 @@
//go:build go1.17
// +build go1.17
package websocket
import (

View File

@@ -6,7 +6,7 @@ package websocket
import (
"crypto/rand"
"crypto/sha1" //#nosec G505 -- (CWE-327) https://datatracker.ietf.org/doc/html/rfc6455#page-54
"crypto/sha1"
"encoding/base64"
"io"
"net/http"
@@ -17,7 +17,7 @@ import (
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
func computeAcceptKey(challengeKey string) string {
h := sha1.New() //#nosec G401 -- (CWE-326) https://datatracker.ietf.org/doc/html/rfc6455#page-54
h := sha1.New()
h.Write([]byte(challengeKey))
h.Write(keyGUID)
return base64.StdEncoding.EncodeToString(h.Sum(nil))

View File

@@ -1,39 +0,0 @@
language: go
os:
- linux
- osx
go:
- 1.16.x
install:
# go-flags
- go build -v ./...
# linting
- go get -v golang.org/x/lint/golint
# code coverage
- go get golang.org/x/tools/cmd/cover
- go get github.com/onsi/ginkgo/ginkgo
- go get github.com/modocache/gover
- if [ "$TRAVIS_SECURE_ENV_VARS" = "true" ]; then go get github.com/mattn/goveralls; fi
script:
# go-flags
- $(exit $(gofmt -l . | wc -l))
- go test -v ./...
# linting
- go tool vet -all=true -v=true . || true
- $(go env GOPATH | awk 'BEGIN{FS=":"} {print $1}')/bin/golint ./...
# code coverage
- $(go env GOPATH | awk 'BEGIN{FS=":"} {print $1}')/bin/ginkgo -r -cover
- $(go env GOPATH | awk 'BEGIN{FS=":"} {print $1}')/bin/gover
- if [ "$TRAVIS_SECURE_ENV_VARS" = "true" ]; then $(go env GOPATH | awk 'BEGIN{FS=":"} {print $1}')/bin/goveralls -coverprofile=gover.coverprofile -service=travis-ci -repotoken $COVERALLS_TOKEN; fi
env:
# coveralls.io
secure: "RCYbiB4P0RjQRIoUx/vG/AjP3mmYCbzOmr86DCww1Z88yNcy3hYr3Cq8rpPtYU5v0g7wTpu4adaKIcqRE9xknYGbqj3YWZiCoBP1/n4Z+9sHW3Dsd9D/GRGeHUus0laJUGARjWoCTvoEtOgTdGQDoX7mH+pUUY0FBltNYUdOiiU="

View File

@@ -1,7 +1,7 @@
go-flags: a go library for parsing command line arguments
=========================================================
[![GoDoc](https://godoc.org/github.com/jessevdk/go-flags?status.png)](https://godoc.org/github.com/jessevdk/go-flags) [![Build Status](https://travis-ci.org/jessevdk/go-flags.svg?branch=master)](https://travis-ci.org/jessevdk/go-flags) [![Coverage Status](https://img.shields.io/coveralls/jessevdk/go-flags.svg)](https://coveralls.io/r/jessevdk/go-flags?branch=master)
[![GoDoc](https://godoc.org/github.com/jessevdk/go-flags?status.png)](https://godoc.org/github.com/jessevdk/go-flags)
This library provides similar functionality to the builtin flag library of
go, but provides much more functionality and nicer formatting. From the
@@ -78,6 +78,9 @@ var opts struct {
// Example of a map
IntMap map[string]int `long:"intmap" description:"A map from string to int"`
// Example of env variable
Thresholds []int `long:"thresholds" default:"1" default:"2" env:"THRESHOLD_VALUES" env-delim:","`
}
// Callback which will invoke callto:<argument> to call a number.

View File

@@ -30,6 +30,12 @@ type Command struct {
// Whether positional arguments are required
ArgsRequired bool
// Whether to pass all arguments after the first non option as remaining
// command line arguments. This is equivalent to strict POSIX processing.
// This is command-local version of PassAfterNonOption Parser flag. It
// cannot be turned off when PassAfterNonOption Parser flag is set.
PassAfterNonOption bool
commands []*Command
hasBuiltinHelpGroup bool
args []*Arg
@@ -244,6 +250,7 @@ func (c *Command) scanSubcommandHandler(parentg *Group) scanHandler {
longDescription := mtag.Get("long-description")
subcommandsOptional := mtag.Get("subcommands-optional")
aliases := mtag.GetMany("alias")
passAfterNonOption := mtag.Get("pass-after-non-option")
subc, err := c.AddCommand(subcommand, shortDescription, longDescription, ptrval.Interface())
@@ -261,6 +268,10 @@ func (c *Command) scanSubcommandHandler(parentg *Group) scanHandler {
subc.Aliases = aliases
}
if len(passAfterNonOption) > 0 {
subc.PassAfterNonOption = true
}
return true, nil
}

View File

@@ -53,7 +53,7 @@ func getBase(options multiTag, base int) (int, error) {
func convertMarshal(val reflect.Value) (bool, string, error) {
// Check first for the Marshaler interface
if val.Type().NumMethod() > 0 && val.CanInterface() {
if val.IsValid() && val.Type().NumMethod() > 0 && val.CanInterface() {
if marshaler, ok := val.Interface().(Marshaler); ok {
ret, err := marshaler.MarshalFlag()
return true, ret, err
@@ -68,6 +68,10 @@ func convertToString(val reflect.Value, options multiTag) (string, error) {
return ret, err
}
if !val.IsValid() {
return "", nil
}
tp := val.Type()
// Support for time.Duration
@@ -220,7 +224,7 @@ func convert(val string, retval reflect.Value, options multiTag) error {
retval.SetBool(b)
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
base, err := getBase(options, 10)
base, err := getBase(options, 0)
if err != nil {
return err
@@ -234,7 +238,7 @@ func convert(val string, retval reflect.Value, options multiTag) error {
retval.SetInt(parsed)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
base, err := getBase(options, 10)
base, err := getBase(options, 0)
if err != nil {
return err
@@ -267,7 +271,12 @@ func convert(val string, retval reflect.Value, options multiTag) error {
retval.Set(reflect.Append(retval, elemval))
case reflect.Map:
parts := strings.SplitN(val, ":", 2)
keyValueDelimiter := options.Get("key-value-delimiter")
if keyValueDelimiter == "" {
keyValueDelimiter = ":"
}
parts := strings.SplitN(val, keyValueDelimiter, 2)
key := parts[0]
var value string

View File

@@ -8,8 +8,7 @@ The flags package is similar in functionality to the go built-in flag package
but provides more options and uses reflection to provide a convenient and
succinct way of specifying command line options.
Supported features
# Supported features
The following features are supported in go-flags:
@@ -31,6 +30,7 @@ The following features are supported in go-flags:
Supports namespaces for (nested) option groups
Additional features specific to Windows:
Options with short names (/v)
Options with long names (/verbose)
Windows-style options with arguments use a colon as the delimiter
@@ -38,8 +38,7 @@ Additional features specific to Windows:
Windows style options can be disabled at build time using the "forceposix"
build tag
Basic usage
# Basic usage
The flags package uses structs, reflection and struct field tags
to allow users to specify command line options. This results in very simple
@@ -71,8 +70,7 @@ Finally, for full control over the conversion between command line argument
values and options, user defined types can choose to implement the Marshaler
and Unmarshaler interfaces.
Available field tags
# Available field tags
The following is a list of tags for struct fields supported by go-flags:
@@ -159,8 +157,7 @@ The following is a list of tags for struct fields supported by go-flags:
Either the `short:` tag or the `long:` must be specified to make the field eligible as an
option.
Option groups
# Option groups
Option groups are a simple way to semantically separate your options. All
options in a particular group are shown together in the help under the name
@@ -174,9 +171,7 @@ There are currently three ways to specify option groups.
3. Add a struct field to the top-level options annotated with the
group:"group-name" tag.
Commands
# Commands
The flags package also has basic support for commands. Commands are often
used in monolithic applications that support various commands or actions.
@@ -211,8 +206,7 @@ However, if the -v flag is defined on the add command, then the first of
the two examples above would fail since the -v flag is not defined before
the add command.
Completion
# Completion
go-flags has builtin support to provide bash completion of flags, commands
and argument values. To use completion, the binary which uses go-flags

View File

@@ -72,15 +72,15 @@ func (p *Parser) getAlignmentInfo() alignmentInfo {
var prevcmd *Command
p.eachActiveGroup(func(c *Command, grp *Group) {
if !grp.showInHelp() {
return
}
if c != prevcmd {
for _, arg := range c.args {
ret.updateLen(arg.Name, c != p.Command)
}
prevcmd = c
}
if !grp.showInHelp() {
return
}
for _, info := range grp.options {
if !info.showInHelp() {
continue
@@ -334,7 +334,11 @@ func (p *Parser) WriteHelp(writer io.Writer) {
}
if !allcmd.ArgsRequired {
if arg.Required > 0 {
fmt.Fprintf(wr, "%s", name)
} else {
fmt.Fprintf(wr, "[%s]", name)
}
} else {
fmt.Fprintf(wr, "%s", name)
}

View File

@@ -584,7 +584,7 @@ func (i *IniParser) parse(ini *ini) error {
if i.ParseAsDefaults {
err = opt.setDefault(pval)
} else {
err = opt.set(pval)
err = opt.Set(pval)
}
if err != nil {

View File

@@ -239,7 +239,7 @@ func (option *Option) IsSetDefault() bool {
// Set the value of an option to the specified value. An error will be returned
// if the specified value could not be converted to the corresponding option
// value type.
func (option *Option) set(value *string) error {
func (option *Option) Set(value *string) error {
kind := option.value.Type().Kind()
if (kind == reflect.Map || kind == reflect.Slice) && option.clearReferenceBeforeSet {
@@ -287,7 +287,7 @@ func (option *Option) setDefault(value *string) error {
return nil
}
if err := option.set(value); err != nil {
if err := option.Set(value); err != nil {
return err
}

View File

@@ -1,3 +1,4 @@
//go:build !windows || forceposix
// +build !windows forceposix
package flags

View File

@@ -1,3 +1,4 @@
//go:build !forceposix
// +build !forceposix
package flags

View File

@@ -113,6 +113,10 @@ const (
// POSIX processing.
PassAfterNonOption
// AllowBoolValues allows a user to assign true/false to a boolean value
// rather than raising an error stating it cannot have an argument.
AllowBoolValues
// Default is a convenient default set of options which should cover
// most of the uses of the flags package.
Default = HelpFlag | PrintErrors | PassDoubleDash
@@ -252,7 +256,7 @@ func (p *Parser) ParseArgs(args []string) ([]string, error) {
}
if !argumentIsOption(arg) {
if (p.Options&PassAfterNonOption) != None && s.lookup.commands[arg] == nil {
if ((p.Options&PassAfterNonOption) != None || s.command.PassAfterNonOption) && s.lookup.commands[arg] == nil {
// If PassAfterNonOption is set then all remaining arguments
// are considered positional
if err = s.addArgs(s.arg); err != nil {
@@ -521,11 +525,10 @@ func (p *parseState) estimateCommand() error {
func (p *Parser) parseOption(s *parseState, name string, option *Option, canarg bool, argument *string) (err error) {
if !option.canArgument() {
if argument != nil {
if argument != nil && (p.Options&AllowBoolValues) == None {
return newErrorf(ErrNoArgumentForBool, "bool flag `%s' cannot have an argument", option)
}
err = option.set(nil)
err = option.Set(argument)
} else if argument != nil || (canarg && !s.eof()) {
var arg string
@@ -546,13 +549,13 @@ func (p *Parser) parseOption(s *parseState, name string, option *Option, canarg
}
if err == nil {
err = option.set(&arg)
err = option.Set(&arg)
}
} else if option.OptionalArgument {
option.empty()
for _, v := range option.OptionalValue {
err = option.set(&v)
err = option.Set(&v)
if err != nil {
break

View File

@@ -1,4 +1,5 @@
// +build !windows,!plan9,!appengine,!wasm
//go:build !windows && !plan9 && !appengine && !wasm && !aix
// +build !windows,!plan9,!appengine,!wasm,!aix
package flags

View File

@@ -1,4 +1,5 @@
// +build plan9 appengine wasm
//go:build plan9 || appengine || wasm || aix
// +build plan9 appengine wasm aix
package flags

View File

@@ -1,3 +1,4 @@
//go:build windows
// +build windows
package flags

View File

@@ -16,8 +16,3 @@ Go support for Protocol Buffers - Google's data interchange format
http://github.com/golang/protobuf/
Copyright 2010 The Go Authors
See source code for license details.
Support for streaming Protocol Buffer messages for the Go language (golang).
https://github.com/matttproud/golang_protobuf_extensions
Copyright 2013 Matt T. Proud
Licensed under the Apache License, Version 2.0

View File

@@ -95,7 +95,8 @@ func (v2) NewDesc(fqName, help string, variableLabels ConstrainableLabels, const
help: help,
variableLabels: variableLabels.compile(),
}
if !model.IsValidMetricName(model.LabelValue(fqName)) {
//nolint:staticcheck // TODO: Don't use deprecated model.NameValidationScheme.
if !model.NameValidationScheme.IsValidMetricName(fqName) {
d.err = fmt.Errorf("%q is not a valid metric name", fqName)
return d
}
@@ -189,7 +190,9 @@ func (d *Desc) String() string {
fmt.Sprintf("%s=%q", lp.GetName(), lp.GetValue()),
)
}
vlStrings := make([]string, 0, len(d.variableLabels.names))
vlStrings := []string{}
if d.variableLabels != nil {
vlStrings = make([]string, 0, len(d.variableLabels.names))
for _, vl := range d.variableLabels.names {
if fn, ok := d.variableLabels.labelConstraints[vl]; ok && fn != nil {
vlStrings = append(vlStrings, fmt.Sprintf("c(%s)", vl))
@@ -197,6 +200,7 @@ func (d *Desc) String() string {
vlStrings = append(vlStrings, vl)
}
}
}
return fmt.Sprintf(
"Desc{fqName: %q, help: %q, constLabels: {%s}, variableLabels: {%s}}",
d.fqName,

View File

@@ -22,13 +22,13 @@ import (
// goRuntimeMemStats provides the metrics initially provided by runtime.ReadMemStats.
// From Go 1.17 those similar (and better) statistics are provided by runtime/metrics, so
// while eval closure works on runtime.MemStats, the struct from Go 1.17+ is
// populated using runtime/metrics.
// populated using runtime/metrics. Those are the defaults we can't alter.
func goRuntimeMemStats() memStatsMetrics {
return memStatsMetrics{
{
desc: NewDesc(
memstatNamespace("alloc_bytes"),
"Number of bytes allocated and still in use.",
"Number of bytes allocated in heap and currently in use. Equals to /memory/classes/heap/objects:bytes.",
nil, nil,
),
eval: func(ms *runtime.MemStats) float64 { return float64(ms.Alloc) },
@@ -36,7 +36,7 @@ func goRuntimeMemStats() memStatsMetrics {
}, {
desc: NewDesc(
memstatNamespace("alloc_bytes_total"),
"Total number of bytes allocated, even if freed.",
"Total number of bytes allocated in heap until now, even if released already. Equals to /gc/heap/allocs:bytes.",
nil, nil,
),
eval: func(ms *runtime.MemStats) float64 { return float64(ms.TotalAlloc) },
@@ -44,23 +44,16 @@ func goRuntimeMemStats() memStatsMetrics {
}, {
desc: NewDesc(
memstatNamespace("sys_bytes"),
"Number of bytes obtained from system.",
"Number of bytes obtained from system. Equals to /memory/classes/total:byte.",
nil, nil,
),
eval: func(ms *runtime.MemStats) float64 { return float64(ms.Sys) },
valType: GaugeValue,
}, {
desc: NewDesc(
memstatNamespace("lookups_total"),
"Total number of pointer lookups.",
nil, nil,
),
eval: func(ms *runtime.MemStats) float64 { return float64(ms.Lookups) },
valType: CounterValue,
}, {
desc: NewDesc(
memstatNamespace("mallocs_total"),
"Total number of mallocs.",
// TODO(bwplotka): We could add go_memstats_heap_objects, probably useful for discovery. Let's gather more feedback, kind of a waste of bytes for everybody for compatibility reasons to keep both, and we can't really rename/remove useful metric.
"Total number of heap objects allocated, both live and gc-ed. Semantically a counter version for go_memstats_heap_objects gauge. Equals to /gc/heap/allocs:objects + /gc/heap/tiny/allocs:objects.",
nil, nil,
),
eval: func(ms *runtime.MemStats) float64 { return float64(ms.Mallocs) },
@@ -68,7 +61,7 @@ func goRuntimeMemStats() memStatsMetrics {
}, {
desc: NewDesc(
memstatNamespace("frees_total"),
"Total number of frees.",
"Total number of heap objects frees. Equals to /gc/heap/frees:objects + /gc/heap/tiny/allocs:objects.",
nil, nil,
),
eval: func(ms *runtime.MemStats) float64 { return float64(ms.Frees) },
@@ -76,7 +69,7 @@ func goRuntimeMemStats() memStatsMetrics {
}, {
desc: NewDesc(
memstatNamespace("heap_alloc_bytes"),
"Number of heap bytes allocated and still in use.",
"Number of heap bytes allocated and currently in use, same as go_memstats_alloc_bytes. Equals to /memory/classes/heap/objects:bytes.",
nil, nil,
),
eval: func(ms *runtime.MemStats) float64 { return float64(ms.HeapAlloc) },
@@ -84,7 +77,7 @@ func goRuntimeMemStats() memStatsMetrics {
}, {
desc: NewDesc(
memstatNamespace("heap_sys_bytes"),
"Number of heap bytes obtained from system.",
"Number of heap bytes obtained from system. Equals to /memory/classes/heap/objects:bytes + /memory/classes/heap/unused:bytes + /memory/classes/heap/released:bytes + /memory/classes/heap/free:bytes.",
nil, nil,
),
eval: func(ms *runtime.MemStats) float64 { return float64(ms.HeapSys) },
@@ -92,7 +85,7 @@ func goRuntimeMemStats() memStatsMetrics {
}, {
desc: NewDesc(
memstatNamespace("heap_idle_bytes"),
"Number of heap bytes waiting to be used.",
"Number of heap bytes waiting to be used. Equals to /memory/classes/heap/released:bytes + /memory/classes/heap/free:bytes.",
nil, nil,
),
eval: func(ms *runtime.MemStats) float64 { return float64(ms.HeapIdle) },
@@ -100,7 +93,7 @@ func goRuntimeMemStats() memStatsMetrics {
}, {
desc: NewDesc(
memstatNamespace("heap_inuse_bytes"),
"Number of heap bytes that are in use.",
"Number of heap bytes that are in use. Equals to /memory/classes/heap/objects:bytes + /memory/classes/heap/unused:bytes",
nil, nil,
),
eval: func(ms *runtime.MemStats) float64 { return float64(ms.HeapInuse) },
@@ -108,7 +101,7 @@ func goRuntimeMemStats() memStatsMetrics {
}, {
desc: NewDesc(
memstatNamespace("heap_released_bytes"),
"Number of heap bytes released to OS.",
"Number of heap bytes released to OS. Equals to /memory/classes/heap/released:bytes.",
nil, nil,
),
eval: func(ms *runtime.MemStats) float64 { return float64(ms.HeapReleased) },
@@ -116,7 +109,7 @@ func goRuntimeMemStats() memStatsMetrics {
}, {
desc: NewDesc(
memstatNamespace("heap_objects"),
"Number of allocated objects.",
"Number of currently allocated objects. Equals to /gc/heap/objects:objects.",
nil, nil,
),
eval: func(ms *runtime.MemStats) float64 { return float64(ms.HeapObjects) },
@@ -124,7 +117,7 @@ func goRuntimeMemStats() memStatsMetrics {
}, {
desc: NewDesc(
memstatNamespace("stack_inuse_bytes"),
"Number of bytes in use by the stack allocator.",
"Number of bytes obtained from system for stack allocator in non-CGO environments. Equals to /memory/classes/heap/stacks:bytes.",
nil, nil,
),
eval: func(ms *runtime.MemStats) float64 { return float64(ms.StackInuse) },
@@ -132,7 +125,7 @@ func goRuntimeMemStats() memStatsMetrics {
}, {
desc: NewDesc(
memstatNamespace("stack_sys_bytes"),
"Number of bytes obtained from system for stack allocator.",
"Number of bytes obtained from system for stack allocator. Equals to /memory/classes/heap/stacks:bytes + /memory/classes/os-stacks:bytes.",
nil, nil,
),
eval: func(ms *runtime.MemStats) float64 { return float64(ms.StackSys) },
@@ -140,7 +133,7 @@ func goRuntimeMemStats() memStatsMetrics {
}, {
desc: NewDesc(
memstatNamespace("mspan_inuse_bytes"),
"Number of bytes in use by mspan structures.",
"Number of bytes in use by mspan structures. Equals to /memory/classes/metadata/mspan/inuse:bytes.",
nil, nil,
),
eval: func(ms *runtime.MemStats) float64 { return float64(ms.MSpanInuse) },
@@ -148,7 +141,7 @@ func goRuntimeMemStats() memStatsMetrics {
}, {
desc: NewDesc(
memstatNamespace("mspan_sys_bytes"),
"Number of bytes used for mspan structures obtained from system.",
"Number of bytes used for mspan structures obtained from system. Equals to /memory/classes/metadata/mspan/inuse:bytes + /memory/classes/metadata/mspan/free:bytes.",
nil, nil,
),
eval: func(ms *runtime.MemStats) float64 { return float64(ms.MSpanSys) },
@@ -156,7 +149,7 @@ func goRuntimeMemStats() memStatsMetrics {
}, {
desc: NewDesc(
memstatNamespace("mcache_inuse_bytes"),
"Number of bytes in use by mcache structures.",
"Number of bytes in use by mcache structures. Equals to /memory/classes/metadata/mcache/inuse:bytes.",
nil, nil,
),
eval: func(ms *runtime.MemStats) float64 { return float64(ms.MCacheInuse) },
@@ -164,7 +157,7 @@ func goRuntimeMemStats() memStatsMetrics {
}, {
desc: NewDesc(
memstatNamespace("mcache_sys_bytes"),
"Number of bytes used for mcache structures obtained from system.",
"Number of bytes used for mcache structures obtained from system. Equals to /memory/classes/metadata/mcache/inuse:bytes + /memory/classes/metadata/mcache/free:bytes.",
nil, nil,
),
eval: func(ms *runtime.MemStats) float64 { return float64(ms.MCacheSys) },
@@ -172,7 +165,7 @@ func goRuntimeMemStats() memStatsMetrics {
}, {
desc: NewDesc(
memstatNamespace("buck_hash_sys_bytes"),
"Number of bytes used by the profiling bucket hash table.",
"Number of bytes used by the profiling bucket hash table. Equals to /memory/classes/profiling/buckets:bytes.",
nil, nil,
),
eval: func(ms *runtime.MemStats) float64 { return float64(ms.BuckHashSys) },
@@ -180,7 +173,7 @@ func goRuntimeMemStats() memStatsMetrics {
}, {
desc: NewDesc(
memstatNamespace("gc_sys_bytes"),
"Number of bytes used for garbage collection system metadata.",
"Number of bytes used for garbage collection system metadata. Equals to /memory/classes/metadata/other:bytes.",
nil, nil,
),
eval: func(ms *runtime.MemStats) float64 { return float64(ms.GCSys) },
@@ -188,7 +181,7 @@ func goRuntimeMemStats() memStatsMetrics {
}, {
desc: NewDesc(
memstatNamespace("other_sys_bytes"),
"Number of bytes used for other system allocations.",
"Number of bytes used for other system allocations. Equals to /memory/classes/other:bytes.",
nil, nil,
),
eval: func(ms *runtime.MemStats) float64 { return float64(ms.OtherSys) },
@@ -196,7 +189,7 @@ func goRuntimeMemStats() memStatsMetrics {
}, {
desc: NewDesc(
memstatNamespace("next_gc_bytes"),
"Number of heap bytes when next garbage collection will take place.",
"Number of heap bytes when next garbage collection will take place. Equals to /gc/heap/goal:bytes.",
nil, nil,
),
eval: func(ms *runtime.MemStats) float64 { return float64(ms.NextGC) },
@@ -225,7 +218,7 @@ func newBaseGoCollector() baseGoCollector {
nil, nil),
gcDesc: NewDesc(
"go_gc_duration_seconds",
"A summary of the pause duration of garbage collection cycles.",
"A summary of the wall-time pause (stop-the-world) duration in garbage collection cycles.",
nil, nil),
gcLastTimeDesc: NewDesc(
"go_memstats_last_gc_time_seconds",

View File

@@ -17,6 +17,7 @@
package prometheus
import (
"fmt"
"math"
"runtime"
"runtime/metrics"
@@ -153,7 +154,8 @@ func defaultGoCollectorOptions() internal.GoCollectorOptions {
"/gc/heap/frees-by-size:bytes": goGCHeapFreesBytes,
},
RuntimeMetricRules: []internal.GoCollectorRule{
//{Matcher: regexp.MustCompile("")},
// Recommended metrics we want by default from runtime/metrics.
{Matcher: internal.GoCollectorDefaultRuntimeMetrics},
},
}
}
@@ -203,6 +205,7 @@ func NewGoCollector(opts ...func(o *internal.GoCollectorOptions)) Collector {
// to fail here. This condition is tested in TestExpectedRuntimeMetrics.
continue
}
help := attachOriginalName(d.Description.Description, d.Name)
sampleBuf = append(sampleBuf, metrics.Sample{Name: d.Name})
sampleMap[d.Name] = &sampleBuf[len(sampleBuf)-1]
@@ -214,7 +217,7 @@ func NewGoCollector(opts ...func(o *internal.GoCollectorOptions)) Collector {
m = newBatchHistogram(
NewDesc(
BuildFQName(namespace, subsystem, name),
d.Description.Description,
help,
nil,
nil,
),
@@ -226,7 +229,7 @@ func NewGoCollector(opts ...func(o *internal.GoCollectorOptions)) Collector {
Namespace: namespace,
Subsystem: subsystem,
Name: name,
Help: d.Description.Description,
Help: help,
},
)
} else {
@@ -234,7 +237,7 @@ func NewGoCollector(opts ...func(o *internal.GoCollectorOptions)) Collector {
Namespace: namespace,
Subsystem: subsystem,
Name: name,
Help: d.Description.Description,
Help: help,
})
}
metricSet = append(metricSet, m)
@@ -284,6 +287,10 @@ func NewGoCollector(opts ...func(o *internal.GoCollectorOptions)) Collector {
}
}
func attachOriginalName(desc, origName string) string {
return fmt.Sprintf("%s Sourced from %s.", desc, origName)
}
// Describe returns all descriptions of the collector.
func (c *goCollector) Describe(ch chan<- *Desc) {
c.base.Describe(ch)
@@ -376,13 +383,13 @@ func unwrapScalarRMValue(v metrics.Value) float64 {
//
// This should never happen because we always populate our metric
// set from the runtime/metrics package.
panic("unexpected unsupported metric")
panic("unexpected bad kind metric")
default:
// Unsupported metric kind.
//
// This should never happen because we check for this during initialization
// and flag and filter metrics whose kinds we don't understand.
panic("unexpected unsupported metric kind")
panic(fmt.Sprintf("unexpected unsupported metric: %v", v.Kind()))
}
}

View File

@@ -14,6 +14,7 @@
package prometheus
import (
"errors"
"fmt"
"math"
"runtime"
@@ -28,6 +29,11 @@ import (
"google.golang.org/protobuf/types/known/timestamppb"
)
const (
nativeHistogramSchemaMaximum = 8
nativeHistogramSchemaMinimum = -4
)
// nativeHistogramBounds for the frac of observed values. Only relevant for
// schema > 0. The position in the slice is the schema. (0 is never used, just
// here for convenience of using the schema directly as the index.)
@@ -330,11 +336,11 @@ func ExponentialBuckets(start, factor float64, count int) []float64 {
// used for the Buckets field of HistogramOpts.
//
// The function panics if 'count' is 0 or negative, if 'min' is 0 or negative.
func ExponentialBucketsRange(min, max float64, count int) []float64 {
func ExponentialBucketsRange(minBucket, maxBucket float64, count int) []float64 {
if count < 1 {
panic("ExponentialBucketsRange count needs a positive count")
}
if min <= 0 {
if minBucket <= 0 {
panic("ExponentialBucketsRange min needs to be greater than 0")
}
@@ -342,12 +348,12 @@ func ExponentialBucketsRange(min, max float64, count int) []float64 {
// max = min*growthFactor^(bucketCount-1)
// We know max/min and highest bucket. Solve for growthFactor.
growthFactor := math.Pow(max/min, 1.0/float64(count-1))
growthFactor := math.Pow(maxBucket/minBucket, 1.0/float64(count-1))
// Now that we know growthFactor, solve for each bucket.
buckets := make([]float64, count)
for i := 1; i <= count; i++ {
buckets[i-1] = min * math.Pow(growthFactor, float64(i-1))
buckets[i-1] = minBucket * math.Pow(growthFactor, float64(i-1))
}
return buckets
}
@@ -440,7 +446,7 @@ type HistogramOpts struct {
// constant (or any negative float value).
NativeHistogramZeroThreshold float64
// The remaining fields define a strategy to limit the number of
// The next three fields define a strategy to limit the number of
// populated sparse buckets. If NativeHistogramMaxBucketNumber is left
// at zero, the number of buckets is not limited. (Note that this might
// lead to unbounded memory consumption if the values observed by the
@@ -473,6 +479,22 @@ type HistogramOpts struct {
NativeHistogramMinResetDuration time.Duration
NativeHistogramMaxZeroThreshold float64
// NativeHistogramMaxExemplars limits the number of exemplars
// that are kept in memory for each native histogram. If you leave it at
// zero, a default value of 10 is used. If no exemplars should be kept specifically
// for native histograms, set it to a negative value. (Scrapers can
// still use the exemplars exposed for classic buckets, which are managed
// independently.)
NativeHistogramMaxExemplars int
// NativeHistogramExemplarTTL is only checked once
// NativeHistogramMaxExemplars is exceeded. In that case, the
// oldest exemplar is removed if it is older than NativeHistogramExemplarTTL.
// Otherwise, the older exemplar in the pair of exemplars that are closest
// together (on an exponential scale) is removed.
// If NativeHistogramExemplarTTL is left at its zero value, a default value of
// 5m is used. To always delete the oldest exemplar, set it to a negative value.
NativeHistogramExemplarTTL time.Duration
// now is for testing purposes, by default it's time.Now.
now func() time.Time
@@ -532,6 +554,7 @@ func newHistogram(desc *Desc, opts HistogramOpts, labelValues ...string) Histogr
if opts.afterFunc == nil {
opts.afterFunc = time.AfterFunc
}
h := &histogram{
desc: desc,
upperBounds: opts.Buckets,
@@ -556,6 +579,7 @@ func newHistogram(desc *Desc, opts HistogramOpts, labelValues ...string) Histogr
h.nativeHistogramZeroThreshold = DefNativeHistogramZeroThreshold
} // Leave h.nativeHistogramZeroThreshold at 0 otherwise.
h.nativeHistogramSchema = pickSchema(opts.NativeHistogramBucketFactor)
h.nativeExemplars = makeNativeExemplars(opts.NativeHistogramExemplarTTL, opts.NativeHistogramMaxExemplars)
}
for i, upperBound := range h.upperBounds {
if i < len(h.upperBounds)-1 {
@@ -726,6 +750,7 @@ type histogram struct {
// scheduled for a later time (when nativeHistogramMinResetDuration has
// passed).
resetScheduled bool
nativeExemplars nativeExemplars
// now is for testing purposes, by default it's time.Now.
now func() time.Time
@@ -742,6 +767,9 @@ func (h *histogram) Observe(v float64) {
h.observe(v, h.findBucket(v))
}
// ObserveWithExemplar should not be called in a high-frequency setting
// for a native histogram with configured exemplars. For this case,
// the implementation isn't lock-free and might suffer from lock contention.
func (h *histogram) ObserveWithExemplar(v float64, e Labels) {
i := h.findBucket(v)
h.observe(v, i)
@@ -821,6 +849,13 @@ func (h *histogram) Write(out *dto.Metric) error {
Length: proto.Uint32(0),
}}
}
if h.nativeExemplars.isEnabled() {
h.nativeExemplars.Lock()
his.Exemplars = append(his.Exemplars, h.nativeExemplars.exemplars...)
h.nativeExemplars.Unlock()
}
}
addAndResetCounts(hotCounts, coldCounts)
return nil
@@ -829,15 +864,35 @@ func (h *histogram) Write(out *dto.Metric) error {
// findBucket returns the index of the bucket for the provided value, or
// len(h.upperBounds) for the +Inf bucket.
func (h *histogram) findBucket(v float64) int {
// TODO(beorn7): For small numbers of buckets (<30), a linear search is
// slightly faster than the binary search. If we really care, we could
// switch from one search strategy to the other depending on the number
// of buckets.
//
// Microbenchmarks (BenchmarkHistogramNoLabels):
// 11 buckets: 38.3 ns/op linear - binary 48.7 ns/op
// 100 buckets: 78.1 ns/op linear - binary 54.9 ns/op
// 300 buckets: 154 ns/op linear - binary 61.6 ns/op
n := len(h.upperBounds)
if n == 0 {
return 0
}
// Early exit: if v is less than or equal to the first upper bound, return 0
if v <= h.upperBounds[0] {
return 0
}
// Early exit: if v is greater than the last upper bound, return len(h.upperBounds)
if v > h.upperBounds[n-1] {
return n
}
// For small arrays, use simple linear search
// "magic number" 35 is result of tests on couple different (AWS and baremetal) servers
// see more details here: https://github.com/prometheus/client_golang/pull/1662
if n < 35 {
for i, bound := range h.upperBounds {
if v <= bound {
return i
}
}
// If v is greater than all upper bounds, return len(h.upperBounds)
return n
}
// For larger arrays, use stdlib's binary search
return sort.SearchFloat64s(h.upperBounds, v)
}
@@ -1091,8 +1146,10 @@ func (h *histogram) resetCounts(counts *histogramCounts) {
deleteSyncMap(&counts.nativeHistogramBucketsPositive)
}
// updateExemplar replaces the exemplar for the provided bucket. With empty
// labels, it's a no-op. It panics if any of the labels is invalid.
// updateExemplar replaces the exemplar for the provided classic bucket.
// With empty labels, it's a no-op. It panics if any of the labels is invalid.
// If histogram is native, the exemplar will be cached into nativeExemplars,
// which has a limit, and will remove one exemplar when limit is reached.
func (h *histogram) updateExemplar(v float64, bucket int, l Labels) {
if l == nil {
return
@@ -1102,6 +1159,10 @@ func (h *histogram) updateExemplar(v float64, bucket int, l Labels) {
panic(err)
}
h.exemplars[bucket].Store(e)
doSparse := h.nativeHistogramSchema > math.MinInt32 && !math.IsNaN(v)
if doSparse {
h.nativeExemplars.addExemplar(e)
}
}
// HistogramVec is a Collector that bundles a set of Histograms that all share the
@@ -1336,6 +1397,48 @@ func MustNewConstHistogram(
return m
}
// NewConstHistogramWithCreatedTimestamp does the same thing as NewConstHistogram but sets the created timestamp.
func NewConstHistogramWithCreatedTimestamp(
desc *Desc,
count uint64,
sum float64,
buckets map[float64]uint64,
ct time.Time,
labelValues ...string,
) (Metric, error) {
if desc.err != nil {
return nil, desc.err
}
if err := validateLabelValues(labelValues, len(desc.variableLabels.names)); err != nil {
return nil, err
}
return &constHistogram{
desc: desc,
count: count,
sum: sum,
buckets: buckets,
labelPairs: MakeLabelPairs(desc, labelValues),
createdTs: timestamppb.New(ct),
}, nil
}
// MustNewConstHistogramWithCreatedTimestamp is a version of NewConstHistogramWithCreatedTimestamp that panics where
// NewConstHistogramWithCreatedTimestamp would have returned an error.
func MustNewConstHistogramWithCreatedTimestamp(
desc *Desc,
count uint64,
sum float64,
buckets map[float64]uint64,
ct time.Time,
labelValues ...string,
) Metric {
m, err := NewConstHistogramWithCreatedTimestamp(desc, count, sum, buckets, ct, labelValues...)
if err != nil {
panic(err)
}
return m
}
type buckSort []*dto.Bucket
func (s buckSort) Len() int {
@@ -1363,9 +1466,9 @@ func pickSchema(bucketFactor float64) int32 {
floor := math.Floor(math.Log2(math.Log2(bucketFactor)))
switch {
case floor <= -8:
return 8
return nativeHistogramSchemaMaximum
case floor >= 4:
return -4
return nativeHistogramSchemaMinimum
default:
return -int32(floor)
}
@@ -1575,3 +1678,379 @@ func addAndResetCounts(hot, cold *histogramCounts) {
atomic.AddUint64(&hot.nativeHistogramZeroBucket, atomic.LoadUint64(&cold.nativeHistogramZeroBucket))
atomic.StoreUint64(&cold.nativeHistogramZeroBucket, 0)
}
type nativeExemplars struct {
sync.Mutex
// Time-to-live for exemplars, it is set to -1 if exemplars are disabled, that is NativeHistogramMaxExemplars is below 0.
// The ttl is used on insertion to remove an exemplar that is older than ttl, if present.
ttl time.Duration
exemplars []*dto.Exemplar
}
func (n *nativeExemplars) isEnabled() bool {
return n.ttl != -1
}
func makeNativeExemplars(ttl time.Duration, maxCount int) nativeExemplars {
if ttl == 0 {
ttl = 5 * time.Minute
}
if maxCount == 0 {
maxCount = 10
}
if maxCount < 0 {
maxCount = 0
ttl = -1
}
return nativeExemplars{
ttl: ttl,
exemplars: make([]*dto.Exemplar, 0, maxCount),
}
}
func (n *nativeExemplars) addExemplar(e *dto.Exemplar) {
if !n.isEnabled() {
return
}
n.Lock()
defer n.Unlock()
// When the number of exemplars has not yet exceeded or
// is equal to cap(n.exemplars), then
// insert the new exemplar directly.
if len(n.exemplars) < cap(n.exemplars) {
var nIdx int
for nIdx = 0; nIdx < len(n.exemplars); nIdx++ {
if *e.Value < *n.exemplars[nIdx].Value {
break
}
}
n.exemplars = append(n.exemplars[:nIdx], append([]*dto.Exemplar{e}, n.exemplars[nIdx:]...)...)
return
}
if len(n.exemplars) == 1 {
// When the number of exemplars is 1, then
// replace the existing exemplar with the new exemplar.
n.exemplars[0] = e
return
}
// From this point on, the number of exemplars is greater than 1.
// When the number of exemplars exceeds the limit, remove one exemplar.
var (
ot = time.Time{} // Oldest timestamp seen. Initial value doesn't matter as we replace it due to otIdx == -1 in the loop.
otIdx = -1 // Index of the exemplar with the oldest timestamp.
md = -1.0 // Logarithm of the delta of the closest pair of exemplars.
// The insertion point of the new exemplar in the exemplars slice after insertion.
// This is calculated purely based on the order of the exemplars by value.
// nIdx == len(n.exemplars) means the new exemplar is to be inserted after the end.
nIdx = -1
// rIdx is ultimately the index for the exemplar that we are replacing with the new exemplar.
// The aim is to keep a good spread of exemplars by value and not let them bunch up too much.
// It is calculated in 3 steps:
// 1. First we set rIdx to the index of the older exemplar within the closest pair by value.
// That is the following will be true (on log scale):
// either the exemplar pair on index (rIdx-1, rIdx) or (rIdx, rIdx+1) will have
// the closest values to each other from all pairs.
// For example, suppose the values are distributed like this:
// |-----------x-------------x----------------x----x-----|
// ^--rIdx as this is older.
// Or like this:
// |-----------x-------------x----------------x----x-----|
// ^--rIdx as this is older.
// 2. If there is an exemplar that expired, then we simple reset rIdx to that index.
// 3. We check if by inserting the new exemplar we would create a closer pair at
// (nIdx-1, nIdx) or (nIdx, nIdx+1) and set rIdx to nIdx-1 or nIdx accordingly to
// keep the spread of exemplars by value; otherwise we keep rIdx as it is.
rIdx = -1
cLog float64 // Logarithm of the current exemplar.
pLog float64 // Logarithm of the previous exemplar.
)
for i, exemplar := range n.exemplars {
// Find the exemplar with the oldest timestamp.
if otIdx == -1 || exemplar.Timestamp.AsTime().Before(ot) {
ot = exemplar.Timestamp.AsTime()
otIdx = i
}
// Find the index at which to insert new the exemplar.
if nIdx == -1 && *e.Value <= *exemplar.Value {
nIdx = i
}
// Find the two closest exemplars and pick the one the with older timestamp.
pLog = cLog
cLog = math.Log(exemplar.GetValue())
if i == 0 {
continue
}
diff := math.Abs(cLog - pLog)
if md == -1 || diff < md {
// The closest exemplar pair is at index: i-1, i.
// Choose the exemplar with the older timestamp for replacement.
md = diff
if n.exemplars[i].Timestamp.AsTime().Before(n.exemplars[i-1].Timestamp.AsTime()) {
rIdx = i
} else {
rIdx = i - 1
}
}
}
// If all existing exemplar are smaller than new exemplar,
// then the exemplar should be inserted at the end.
if nIdx == -1 {
nIdx = len(n.exemplars)
}
// Here, we have the following relationships:
// n.exemplars[nIdx-1].Value < e.Value (if nIdx > 0)
// e.Value <= n.exemplars[nIdx].Value (if nIdx < len(n.exemplars))
if otIdx != -1 && e.Timestamp.AsTime().Sub(ot) > n.ttl {
// If the oldest exemplar has expired, then replace it with the new exemplar.
rIdx = otIdx
} else {
// In the previous for loop, when calculating the closest pair of exemplars,
// we did not take into account the newly inserted exemplar.
// So we need to calculate with the newly inserted exemplar again.
elog := math.Log(e.GetValue())
if nIdx > 0 {
diff := math.Abs(elog - math.Log(n.exemplars[nIdx-1].GetValue()))
if diff < md {
// The value we are about to insert is closer to the previous exemplar at the insertion point than what we calculated before in rIdx.
// v--rIdx
// |-----------x-n-----------x----------------x----x-----|
// nIdx-1--^ ^--new exemplar value
// Do not make the spread worse, replace nIdx-1 and not rIdx.
md = diff
rIdx = nIdx - 1
}
}
if nIdx < len(n.exemplars) {
diff := math.Abs(math.Log(n.exemplars[nIdx].GetValue()) - elog)
if diff < md {
// The value we are about to insert is closer to the next exemplar at the insertion point than what we calculated before in rIdx.
// v--rIdx
// |-----------x-----------n-x----------------x----x-----|
// new exemplar value--^ ^--nIdx
// Do not make the spread worse, replace nIdx-1 and not rIdx.
rIdx = nIdx
}
}
}
// Adjust the slice according to rIdx and nIdx.
switch {
case rIdx == nIdx:
n.exemplars[nIdx] = e
case rIdx < nIdx:
n.exemplars = append(n.exemplars[:rIdx], append(n.exemplars[rIdx+1:nIdx], append([]*dto.Exemplar{e}, n.exemplars[nIdx:]...)...)...)
case rIdx > nIdx:
n.exemplars = append(n.exemplars[:nIdx], append([]*dto.Exemplar{e}, append(n.exemplars[nIdx:rIdx], n.exemplars[rIdx+1:]...)...)...)
}
}
type constNativeHistogram struct {
desc *Desc
dto.Histogram
labelPairs []*dto.LabelPair
}
func validateCount(sum float64, count uint64, negativeBuckets, positiveBuckets map[int]int64, zeroBucket uint64) error {
var bucketPopulationSum int64
for _, v := range positiveBuckets {
bucketPopulationSum += v
}
for _, v := range negativeBuckets {
bucketPopulationSum += v
}
bucketPopulationSum += int64(zeroBucket)
// If the sum of observations is NaN, the number of observations must be greater or equal to the sum of all bucket counts.
// Otherwise, the number of observations must be equal to the sum of all bucket counts .
if math.IsNaN(sum) && bucketPopulationSum > int64(count) ||
!math.IsNaN(sum) && bucketPopulationSum != int64(count) {
return errors.New("the sum of all bucket populations exceeds the count of observations")
}
return nil
}
// NewConstNativeHistogram returns a metric representing a Prometheus native histogram with
// fixed values for the count, sum, and positive/negative/zero bucket counts. As those parameters
// cannot be changed, the returned value does not implement the Histogram
// interface (but only the Metric interface). Users of this package will not
// have much use for it in regular operations. However, when implementing custom
// OpenTelemetry Collectors, it is useful as a throw-away metric that is generated on the fly
// to send it to Prometheus in the Collect method.
//
// zeroBucket counts all (positive and negative)
// observations in the zero bucket (with an absolute value less or equal
// the current threshold).
// positiveBuckets and negativeBuckets are separate maps for negative and positive
// observations. The map's value is an int64, counting observations in
// that bucket. The map's key is the
// index of the bucket according to the used
// Schema. Index 0 is for an upper bound of 1 in positive buckets and for a lower bound of -1 in negative buckets.
// NewConstNativeHistogram returns an error if
// - the length of labelValues is not consistent with the variable labels in Desc or if Desc is invalid.
// - the schema passed is not between 8 and -4
// - the sum of counts in all buckets including the zero bucket does not equal the count if sum is not NaN (or exceeds the count if sum is NaN)
//
// See https://opentelemetry.io/docs/specs/otel/compatibility/prometheus_and_openmetrics/#exponential-histograms for more details about the conversion from OTel to Prometheus.
func NewConstNativeHistogram(
desc *Desc,
count uint64,
sum float64,
positiveBuckets, negativeBuckets map[int]int64,
zeroBucket uint64,
schema int32,
zeroThreshold float64,
createdTimestamp time.Time,
labelValues ...string,
) (Metric, error) {
if desc.err != nil {
return nil, desc.err
}
if err := validateLabelValues(labelValues, len(desc.variableLabels.names)); err != nil {
return nil, err
}
if schema > nativeHistogramSchemaMaximum || schema < nativeHistogramSchemaMinimum {
return nil, errors.New("invalid native histogram schema")
}
if err := validateCount(sum, count, negativeBuckets, positiveBuckets, zeroBucket); err != nil {
return nil, err
}
NegativeSpan, NegativeDelta := makeBucketsFromMap(negativeBuckets)
PositiveSpan, PositiveDelta := makeBucketsFromMap(positiveBuckets)
ret := &constNativeHistogram{
desc: desc,
Histogram: dto.Histogram{
CreatedTimestamp: timestamppb.New(createdTimestamp),
Schema: &schema,
ZeroThreshold: &zeroThreshold,
SampleCount: &count,
SampleSum: &sum,
NegativeSpan: NegativeSpan,
NegativeDelta: NegativeDelta,
PositiveSpan: PositiveSpan,
PositiveDelta: PositiveDelta,
ZeroCount: proto.Uint64(zeroBucket),
},
labelPairs: MakeLabelPairs(desc, labelValues),
}
if *ret.ZeroThreshold == 0 && *ret.ZeroCount == 0 && len(ret.PositiveSpan) == 0 && len(ret.NegativeSpan) == 0 {
ret.PositiveSpan = []*dto.BucketSpan{{
Offset: proto.Int32(0),
Length: proto.Uint32(0),
}}
}
return ret, nil
}
// MustNewConstNativeHistogram is a version of NewConstNativeHistogram that panics where
// NewConstNativeHistogram would have returned an error.
func MustNewConstNativeHistogram(
desc *Desc,
count uint64,
sum float64,
positiveBuckets, negativeBuckets map[int]int64,
zeroBucket uint64,
nativeHistogramSchema int32,
nativeHistogramZeroThreshold float64,
createdTimestamp time.Time,
labelValues ...string,
) Metric {
nativehistogram, err := NewConstNativeHistogram(desc,
count,
sum,
positiveBuckets,
negativeBuckets,
zeroBucket,
nativeHistogramSchema,
nativeHistogramZeroThreshold,
createdTimestamp,
labelValues...)
if err != nil {
panic(err)
}
return nativehistogram
}
func (h *constNativeHistogram) Desc() *Desc {
return h.desc
}
func (h *constNativeHistogram) Write(out *dto.Metric) error {
out.Histogram = &h.Histogram
out.Label = h.labelPairs
return nil
}
func makeBucketsFromMap(buckets map[int]int64) ([]*dto.BucketSpan, []int64) {
if len(buckets) == 0 {
return nil, nil
}
var ii []int
for k := range buckets {
ii = append(ii, k)
}
sort.Ints(ii)
var (
spans []*dto.BucketSpan
deltas []int64
prevCount int64
nextI int
)
appendDelta := func(count int64) {
*spans[len(spans)-1].Length++
deltas = append(deltas, count-prevCount)
prevCount = count
}
for n, i := range ii {
count := buckets[i]
// Multiple spans with only small gaps in between are probably
// encoded more efficiently as one larger span with a few empty
// buckets. Needs some research to find the sweet spot. For now,
// we assume that gaps of one or two buckets should not create
// a new span.
iDelta := int32(i - nextI)
if n == 0 || iDelta > 2 {
// We have to create a new span, either because we are
// at the very beginning, or because we have found a gap
// of more than two buckets.
spans = append(spans, &dto.BucketSpan{
Offset: proto.Int32(iDelta),
Length: proto.Uint32(0),
})
} else {
// We have found a small gap (or no gap at all).
// Insert empty buckets as needed.
for j := int32(0); j < iDelta; j++ {
appendDelta(0)
}
}
appendDelta(count)
nextI = i + 1
}
return spans, deltas
}

View File

@@ -22,17 +22,18 @@ import (
"bytes"
"fmt"
"io"
"strconv"
"strings"
)
func min(a, b int) int {
func minInt(a, b int) int {
if a < b {
return a
}
return b
}
func max(a, b int) int {
func maxInt(a, b int) int {
if a > b {
return a
}
@@ -427,12 +428,12 @@ func (m *SequenceMatcher) GetGroupedOpCodes(n int) [][]OpCode {
if codes[0].Tag == 'e' {
c := codes[0]
i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2
codes[0] = OpCode{c.Tag, max(i1, i2-n), i2, max(j1, j2-n), j2}
codes[0] = OpCode{c.Tag, maxInt(i1, i2-n), i2, maxInt(j1, j2-n), j2}
}
if codes[len(codes)-1].Tag == 'e' {
c := codes[len(codes)-1]
i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2
codes[len(codes)-1] = OpCode{c.Tag, i1, min(i2, i1+n), j1, min(j2, j1+n)}
codes[len(codes)-1] = OpCode{c.Tag, i1, minInt(i2, i1+n), j1, minInt(j2, j1+n)}
}
nn := n + n
groups := [][]OpCode{}
@@ -443,16 +444,16 @@ func (m *SequenceMatcher) GetGroupedOpCodes(n int) [][]OpCode {
// there is a large range with no changes.
if c.Tag == 'e' && i2-i1 > nn {
group = append(group, OpCode{
c.Tag, i1, min(i2, i1+n),
j1, min(j2, j1+n),
c.Tag, i1, minInt(i2, i1+n),
j1, minInt(j2, j1+n),
})
groups = append(groups, group)
group = []OpCode{}
i1, j1 = max(i1, i2-n), max(j1, j2-n)
i1, j1 = maxInt(i1, i2-n), maxInt(j1, j2-n)
}
group = append(group, OpCode{c.Tag, i1, i2, j1, j2})
}
if len(group) > 0 && !(len(group) == 1 && group[0].Tag == 'e') {
if len(group) > 0 && (len(group) != 1 || group[0].Tag != 'e') {
groups = append(groups, group)
}
return groups
@@ -515,7 +516,7 @@ func (m *SequenceMatcher) QuickRatio() float64 {
// is faster to compute than either .Ratio() or .QuickRatio().
func (m *SequenceMatcher) RealQuickRatio() float64 {
la, lb := len(m.a), len(m.b)
return calculateRatio(min(la, lb), la+lb)
return calculateRatio(minInt(la, lb), la+lb)
}
// Convert range to the "ed" format
@@ -524,7 +525,7 @@ func formatRangeUnified(start, stop int) string {
beginning := start + 1 // lines start numbering with one
length := stop - start
if length == 1 {
return fmt.Sprintf("%d", beginning)
return strconv.Itoa(beginning)
}
if length == 0 {
beginning-- // empty ranges begin at line just before the range
@@ -567,7 +568,7 @@ func WriteUnifiedDiff(writer io.Writer, diff UnifiedDiff) error {
buf := bufio.NewWriter(writer)
defer buf.Flush()
wf := func(format string, args ...interface{}) error {
_, err := buf.WriteString(fmt.Sprintf(format, args...))
_, err := fmt.Fprintf(buf, format, args...)
return err
}
ws := func(s string) error {

View File

@@ -30,3 +30,5 @@ type GoCollectorOptions struct {
RuntimeMetricSumForHist map[string]string
RuntimeMetricRules []GoCollectorRule
}
var GoCollectorDefaultRuntimeMetrics = regexp.MustCompile(`/gc/gogc:percent|/gc/gomemlimit:bytes|/sched/gomaxprocs:threads`)

View File

@@ -66,7 +66,8 @@ func RuntimeMetricsToProm(d *metrics.Description) (string, string, string, bool)
name += "_total"
}
valid := model.IsValidMetricName(model.LabelValue(namespace + "_" + subsystem + "_" + name))
// Our current conversion moves to legacy naming, so use legacy validation.
valid := model.LegacyValidation.IsValidMetricName(namespace + "_" + subsystem + "_" + name)
switch d.Kind {
case metrics.KindUint64:
case metrics.KindFloat64:

View File

@@ -184,5 +184,6 @@ func validateLabelValues(vals []string, expectedNumberOfValues int) error {
}
func checkLabelName(l string) bool {
return model.LabelName(l).IsValid() && !strings.HasPrefix(l, reservedLabelPrefix)
//nolint:staticcheck // TODO: Don't use deprecated model.NameValidationScheme.
return model.NameValidationScheme.IsValidLabelName(l) && !strings.HasPrefix(l, reservedLabelPrefix)
}

View File

@@ -108,15 +108,23 @@ func BuildFQName(namespace, subsystem, name string) string {
if name == "" {
return ""
}
switch {
case namespace != "" && subsystem != "":
return strings.Join([]string{namespace, subsystem, name}, "_")
case namespace != "":
return strings.Join([]string{namespace, name}, "_")
case subsystem != "":
return strings.Join([]string{subsystem, name}, "_")
sb := strings.Builder{}
sb.Grow(len(namespace) + len(subsystem) + len(name) + 2)
if namespace != "" {
sb.WriteString(namespace)
sb.WriteString("_")
}
return name
if subsystem != "" {
sb.WriteString(subsystem)
sb.WriteString("_")
}
sb.WriteString(name)
return sb.String()
}
type invalidMetric struct {
@@ -178,21 +186,31 @@ func (m *withExemplarsMetric) Write(pb *dto.Metric) error {
case pb.Counter != nil:
pb.Counter.Exemplar = m.exemplars[len(m.exemplars)-1]
case pb.Histogram != nil:
h := pb.Histogram
for _, e := range m.exemplars {
// pb.Histogram.Bucket are sorted by UpperBound.
i := sort.Search(len(pb.Histogram.Bucket), func(i int) bool {
return pb.Histogram.Bucket[i].GetUpperBound() >= e.GetValue()
if (h.GetZeroThreshold() != 0 || h.GetZeroCount() != 0 ||
len(h.PositiveSpan) != 0 || len(h.NegativeSpan) != 0) &&
e.GetTimestamp() != nil {
h.Exemplars = append(h.Exemplars, e)
if len(h.Bucket) == 0 {
// Don't proceed to classic buckets if there are none.
continue
}
}
// h.Bucket are sorted by UpperBound.
i := sort.Search(len(h.Bucket), func(i int) bool {
return h.Bucket[i].GetUpperBound() >= e.GetValue()
})
if i < len(pb.Histogram.Bucket) {
pb.Histogram.Bucket[i].Exemplar = e
if i < len(h.Bucket) {
h.Bucket[i].Exemplar = e
} else {
// The +Inf bucket should be explicitly added if there is an exemplar for it, similar to non-const histogram logic in https://github.com/prometheus/client_golang/blob/main/prometheus/histogram.go#L357-L365.
b := &dto.Bucket{
CumulativeCount: proto.Uint64(pb.Histogram.GetSampleCount()),
CumulativeCount: proto.Uint64(h.GetSampleCount()),
UpperBound: proto.Float64(math.Inf(1)),
Exemplar: e,
}
pb.Histogram.Bucket = append(pb.Histogram.Bucket, b)
h.Bucket = append(h.Bucket, b)
}
}
default:
@@ -219,6 +237,7 @@ type Exemplar struct {
// Only last applicable exemplar is injected from the list.
// For example for Counter it means last exemplar is injected.
// For Histogram, it means last applicable exemplar for each bucket is injected.
// For a Native Histogram, all valid exemplars are injected.
//
// NewMetricWithExemplars works best with MustNewConstMetric and
// MustNewConstHistogram, see example.
@@ -234,7 +253,7 @@ func NewMetricWithExemplars(m Metric, exemplars ...Exemplar) (Metric, error) {
)
for i, e := range exemplars {
ts := e.Timestamp
if ts == (time.Time{}) {
if ts.IsZero() {
ts = now
}
exs[i], err = newExemplar(e.Value, ts, e.Labels)

View File

@@ -23,6 +23,7 @@ import (
type processCollector struct {
collectFn func(chan<- Metric)
describeFn func(chan<- *Desc)
pidFn func() (int, error)
reportErrors bool
cpuTotal *Desc
@@ -30,6 +31,7 @@ type processCollector struct {
vsize, maxVsize *Desc
rss *Desc
startTime *Desc
inBytes, outBytes *Desc
}
// ProcessCollectorOpts defines the behavior of a process metrics collector
@@ -100,6 +102,16 @@ func NewProcessCollector(opts ProcessCollectorOpts) Collector {
"Start time of the process since unix epoch in seconds.",
nil, nil,
),
inBytes: NewDesc(
ns+"process_network_receive_bytes_total",
"Number of bytes received by the process over the network.",
nil, nil,
),
outBytes: NewDesc(
ns+"process_network_transmit_bytes_total",
"Number of bytes sent by the process over the network.",
nil, nil,
),
}
if opts.PidFn == nil {
@@ -111,24 +123,23 @@ func NewProcessCollector(opts ProcessCollectorOpts) Collector {
// Set up process metric collection if supported by the runtime.
if canCollectProcess() {
c.collectFn = c.processCollect
c.describeFn = c.describe
} else {
c.collectFn = func(ch chan<- Metric) {
c.reportError(ch, nil, errors.New("process metrics not supported on this platform"))
}
c.collectFn = c.errorCollectFn
c.describeFn = c.errorDescribeFn
}
return c
}
// Describe returns all descriptions of the collector.
func (c *processCollector) Describe(ch chan<- *Desc) {
ch <- c.cpuTotal
ch <- c.openFDs
ch <- c.maxFDs
ch <- c.vsize
ch <- c.maxVsize
ch <- c.rss
ch <- c.startTime
func (c *processCollector) errorCollectFn(ch chan<- Metric) {
c.reportError(ch, nil, errors.New("process metrics not supported on this platform"))
}
func (c *processCollector) errorDescribeFn(ch chan<- *Desc) {
if c.reportErrors {
ch <- NewInvalidDesc(errors.New("process metrics not supported on this platform"))
}
}
// Collect returns the current state of all metrics of the collector.
@@ -136,6 +147,11 @@ func (c *processCollector) Collect(ch chan<- Metric) {
c.collectFn(ch)
}
// Describe returns all descriptions of the collector.
func (c *processCollector) Describe(ch chan<- *Desc) {
c.describeFn(ch)
}
func (c *processCollector) reportError(ch chan<- Metric, desc *Desc, err error) {
if !c.reportErrors {
return

View File

@@ -1,26 +0,0 @@
// Copyright 2019 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build js
// +build js
package prometheus
func canCollectProcess() bool {
return false
}
func (c *processCollector) processCollect(ch chan<- Metric) {
// noop on this platform
return
}

View File

@@ -1,66 +0,0 @@
// Copyright 2019 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build !windows && !js && !wasip1
// +build !windows,!js,!wasip1
package prometheus
import (
"github.com/prometheus/procfs"
)
func canCollectProcess() bool {
_, err := procfs.NewDefaultFS()
return err == nil
}
func (c *processCollector) processCollect(ch chan<- Metric) {
pid, err := c.pidFn()
if err != nil {
c.reportError(ch, nil, err)
return
}
p, err := procfs.NewProc(pid)
if err != nil {
c.reportError(ch, nil, err)
return
}
if stat, err := p.Stat(); err == nil {
ch <- MustNewConstMetric(c.cpuTotal, CounterValue, stat.CPUTime())
ch <- MustNewConstMetric(c.vsize, GaugeValue, float64(stat.VirtualMemory()))
ch <- MustNewConstMetric(c.rss, GaugeValue, float64(stat.ResidentMemory()))
if startTime, err := stat.StartTime(); err == nil {
ch <- MustNewConstMetric(c.startTime, GaugeValue, startTime)
} else {
c.reportError(ch, c.startTime, err)
}
} else {
c.reportError(ch, nil, err)
}
if fds, err := p.FileDescriptorsLen(); err == nil {
ch <- MustNewConstMetric(c.openFDs, GaugeValue, float64(fds))
} else {
c.reportError(ch, c.openFDs, err)
}
if limits, err := p.Limits(); err == nil {
ch <- MustNewConstMetric(c.maxFDs, GaugeValue, float64(limits.OpenFiles))
ch <- MustNewConstMetric(c.maxVsize, GaugeValue, float64(limits.AddressSpace))
} else {
c.reportError(ch, nil, err)
}
}

View File

@@ -1,26 +0,0 @@
// Copyright 2023 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build wasip1
// +build wasip1
package prometheus
func canCollectProcess() bool {
return false
}
func (*processCollector) processCollect(chan<- Metric) {
// noop on this platform
return
}

View File

@@ -79,14 +79,10 @@ func getProcessHandleCount(handle windows.Handle) (uint32, error) {
}
func (c *processCollector) processCollect(ch chan<- Metric) {
h, err := windows.GetCurrentProcess()
if err != nil {
c.reportError(ch, nil, err)
return
}
h := windows.CurrentProcess()
var startTime, exitTime, kernelTime, userTime windows.Filetime
err = windows.GetProcessTimes(h, &startTime, &exitTime, &kernelTime, &userTime)
err := windows.GetProcessTimes(h, &startTime, &exitTime, &kernelTime, &userTime)
if err != nil {
c.reportError(ch, nil, err)
return
@@ -111,6 +107,19 @@ func (c *processCollector) processCollect(ch chan<- Metric) {
ch <- MustNewConstMetric(c.maxFDs, GaugeValue, float64(16*1024*1024)) // Windows has a hard-coded max limit, not per-process.
}
// describe returns all descriptions of the collector for windows.
// Ensure that this list of descriptors is kept in sync with the metrics collected
// in the processCollect method. Any changes to the metrics in processCollect
// (such as adding or removing metrics) should be reflected in this list of descriptors.
func (c *processCollector) describe(ch chan<- *Desc) {
ch <- c.cpuTotal
ch <- c.openFDs
ch <- c.maxFDs
ch <- c.vsize
ch <- c.rss
ch <- c.startTime
}
func fileTimeToSeconds(ft windows.Filetime) float64 {
return float64(uint64(ft.HighDateTime)<<32+uint64(ft.LowDateTime)) / 1e7
}

View File

@@ -76,6 +76,12 @@ func (r *responseWriterDelegator) Write(b []byte) (int, error) {
return n, err
}
// Unwrap lets http.ResponseController get the underlying http.ResponseWriter,
// by implementing the [rwUnwrapper](https://cs.opensource.google/go/go/+/refs/tags/go1.21.4:src/net/http/responsecontroller.go;l=42-44) interface.
func (r *responseWriterDelegator) Unwrap() http.ResponseWriter {
return r.ResponseWriter
}
type (
closeNotifierDelegator struct{ *responseWriterDelegator }
flusherDelegator struct{ *responseWriterDelegator }

View File

@@ -38,13 +38,14 @@ import (
"io"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/prometheus/common/expfmt"
"github.com/prometheus/client_golang/internal/github.com/golang/gddo/httputil"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp/internal"
)
const (
@@ -54,6 +55,24 @@ const (
processStartTimeHeader = "Process-Start-Time-Unix"
)
// Compression represents the content encodings handlers support for the HTTP
// responses.
type Compression string
const (
Identity Compression = "identity"
Gzip Compression = "gzip"
Zstd Compression = "zstd"
)
func defaultCompressionFormats() []Compression {
if internal.NewZstdWriter != nil {
return []Compression{Identity, Gzip, Zstd}
} else {
return []Compression{Identity, Gzip}
}
}
var gzipPool = sync.Pool{
New: func() interface{} {
return gzip.NewWriter(nil)
@@ -122,6 +141,18 @@ func HandlerForTransactional(reg prometheus.TransactionalGatherer, opts HandlerO
}
}
// Select compression formats to offer based on default or user choice.
var compressions []string
if !opts.DisableCompression {
offers := defaultCompressionFormats()
if len(opts.OfferedCompressions) > 0 {
offers = opts.OfferedCompressions
}
for _, comp := range offers {
compressions = append(compressions, string(comp))
}
}
h := http.HandlerFunc(func(rsp http.ResponseWriter, req *http.Request) {
if !opts.ProcessStartTime.IsZero() {
rsp.Header().Set(processStartTimeHeader, strconv.FormatInt(opts.ProcessStartTime.Unix(), 10))
@@ -165,22 +196,30 @@ func HandlerForTransactional(reg prometheus.TransactionalGatherer, opts HandlerO
} else {
contentType = expfmt.Negotiate(req.Header)
}
header := rsp.Header()
header.Set(contentTypeHeader, string(contentType))
rsp.Header().Set(contentTypeHeader, string(contentType))
w := io.Writer(rsp)
if !opts.DisableCompression && gzipAccepted(req.Header) {
header.Set(contentEncodingHeader, "gzip")
gz := gzipPool.Get().(*gzip.Writer)
defer gzipPool.Put(gz)
gz.Reset(w)
defer gz.Close()
w = gz
w, encodingHeader, closeWriter, err := negotiateEncodingWriter(req, rsp, compressions)
if err != nil {
if opts.ErrorLog != nil {
opts.ErrorLog.Println("error getting writer", err)
}
w = io.Writer(rsp)
encodingHeader = string(Identity)
}
enc := expfmt.NewEncoder(w, contentType)
defer closeWriter()
// Set Content-Encoding only when data is compressed
if encodingHeader != string(Identity) {
rsp.Header().Set(contentEncodingHeader, encodingHeader)
}
var enc expfmt.Encoder
if opts.EnableOpenMetricsTextCreatedSamples {
enc = expfmt.NewEncoder(w, contentType, expfmt.WithCreatedLines())
} else {
enc = expfmt.NewEncoder(w, contentType)
}
// handleError handles the error according to opts.ErrorHandling
// and returns true if we have to abort after the handling.
@@ -343,9 +382,19 @@ type HandlerOpts struct {
// no effect on the HTTP status code because ErrorHandling is set to
// ContinueOnError.
Registry prometheus.Registerer
// If DisableCompression is true, the handler will never compress the
// response, even if requested by the client.
// DisableCompression disables the response encoding (compression) and
// encoding negotiation. If true, the handler will
// never compress the response, even if requested
// by the client and the OfferedCompressions field is set.
DisableCompression bool
// OfferedCompressions is a set of encodings (compressions) handler will
// try to offer when negotiating with the client. This defaults to identity, gzip
// and zstd.
// NOTE: If handler can't agree with the client on the encodings or
// unsupported or empty encodings are set in OfferedCompressions,
// handler always fallbacks to no compression (identity), for
// compatibility reasons. In such cases ErrorLog will be used if set.
OfferedCompressions []Compression
// The number of concurrent HTTP requests is limited to
// MaxRequestsInFlight. Additional requests are responded to with 503
// Service Unavailable and a suitable message in the body. If
@@ -371,6 +420,21 @@ type HandlerOpts struct {
// (which changes the identity of the resulting series on the Prometheus
// server).
EnableOpenMetrics bool
// EnableOpenMetricsTextCreatedSamples specifies if this handler should add, extra, synthetic
// Created Timestamps for counters, histograms and summaries, which for the current
// version of OpenMetrics are defined as extra series with the same name and "_created"
// suffix. See also the OpenMetrics specification for more details
// https://github.com/prometheus/OpenMetrics/blob/v1.0.0/specification/OpenMetrics.md#counter-1
//
// Created timestamps are used to improve the accuracy of reset detection,
// but the way it's designed in OpenMetrics 1.0 it also dramatically increases cardinality
// if the scraper does not handle those metrics correctly (converting to created timestamp
// instead of leaving those series as-is). New OpenMetrics versions might improve
// this situation.
//
// Prometheus introduced the feature flag 'created-timestamp-zero-ingestion'
// in version 2.50.0 to handle this situation.
EnableOpenMetricsTextCreatedSamples bool
// ProcessStartTime allows setting process start timevalue that will be exposed
// with "Process-Start-Time-Unix" response header along with the metrics
// payload. This allow callers to have efficient transformations to cumulative
@@ -381,19 +445,6 @@ type HandlerOpts struct {
ProcessStartTime time.Time
}
// gzipAccepted returns whether the client will accept gzip-encoded content.
func gzipAccepted(header http.Header) bool {
a := header.Get(acceptEncodingHeader)
parts := strings.Split(a, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "gzip" || strings.HasPrefix(part, "gzip;") {
return true
}
}
return false
}
// httpError removes any content-encoding header and then calls http.Error with
// the provided error and http.StatusInternalServerError. Error contents is
// supposed to be uncompressed plain text. Same as with a plain http.Error, this
@@ -406,3 +457,36 @@ func httpError(rsp http.ResponseWriter, err error) {
http.StatusInternalServerError,
)
}
// negotiateEncodingWriter reads the Accept-Encoding header from a request and
// selects the right compression based on an allow-list of supported
// compressions. It returns a writer implementing the compression and an the
// correct value that the caller can set in the response header.
func negotiateEncodingWriter(r *http.Request, rw io.Writer, compressions []string) (_ io.Writer, encodingHeaderValue string, closeWriter func(), _ error) {
if len(compressions) == 0 {
return rw, string(Identity), func() {}, nil
}
// TODO(mrueg): Replace internal/github.com/gddo once https://github.com/golang/go/issues/19307 is implemented.
selected := httputil.NegotiateContentEncoding(r, compressions)
switch selected {
case "zstd":
if internal.NewZstdWriter == nil {
// The content encoding was not implemented yet.
return nil, "", func() {}, fmt.Errorf("content compression format not recognized: %s. Valid formats are: %s", selected, defaultCompressionFormats())
}
writer, closeWriter, err := internal.NewZstdWriter(rw)
return writer, selected, closeWriter, err
case "gzip":
gz := gzipPool.Get().(*gzip.Writer)
gz.Reset(rw)
return gz, selected, func() { _ = gz.Close(); gzipPool.Put(gz) }, nil
case "identity":
// This means the content is not compressed.
return rw, selected, func() {}, nil
default:
// The content encoding was not implemented yet.
return nil, "", func() {}, fmt.Errorf("content compression format not recognized: %s. Valid formats are: %s", selected, defaultCompressionFormats())
}
}

View File

@@ -392,7 +392,7 @@ func isLabelCurried(c prometheus.Collector, label string) bool {
func labels(code, method bool, reqMethod string, status int, extraMethods ...string) prometheus.Labels {
labels := prometheus.Labels{}
if !(code || method) {
if !code && !method {
return labels
}

View File

@@ -314,17 +314,18 @@ func (r *Registry) Register(c Collector) error {
if dimHash != desc.dimHash {
return fmt.Errorf("a previously registered descriptor with the same fully-qualified name as %s has different label names or a different help string", desc)
}
} else {
continue
}
// ...then check the new descriptors already seen.
if dimHash, exists := newDimHashesByName[desc.fqName]; exists {
if dimHash != desc.dimHash {
return fmt.Errorf("descriptors reported by collector have inconsistent label names or help strings for the same fully-qualified name, offender is %s", desc)
}
} else {
continue
}
newDimHashesByName[desc.fqName] = desc.dimHash
}
}
}
// A Collector yielding no Desc at all is considered unchecked.
if len(newDescIDs) == 0 {
r.uncheckedCollectors = append(r.uncheckedCollectors, c)

View File

@@ -243,6 +243,7 @@ func newSummary(desc *Desc, opts SummaryOpts, labelValues ...string) Summary {
s := &summary{
desc: desc,
now: opts.now,
objectives: opts.Objectives,
sortedObjectives: make([]float64, 0, len(opts.Objectives)),
@@ -280,6 +281,8 @@ type summary struct {
desc *Desc
now func() time.Time
objectives map[float64]float64
sortedObjectives []float64
@@ -307,7 +310,7 @@ func (s *summary) Observe(v float64) {
s.bufMtx.Lock()
defer s.bufMtx.Unlock()
now := time.Now()
now := s.now()
if now.After(s.hotBufExpTime) {
s.asyncFlush(now)
}
@@ -326,7 +329,7 @@ func (s *summary) Write(out *dto.Metric) error {
s.bufMtx.Lock()
s.mtx.Lock()
// Swap bufs even if hotBuf is empty to set new hotBufExpTime.
s.swapBufs(time.Now())
s.swapBufs(s.now())
s.bufMtx.Unlock()
s.flushColdBuf()
@@ -783,3 +786,45 @@ func MustNewConstSummary(
}
return m
}
// NewConstSummaryWithCreatedTimestamp does the same thing as NewConstSummary but sets the created timestamp.
func NewConstSummaryWithCreatedTimestamp(
desc *Desc,
count uint64,
sum float64,
quantiles map[float64]float64,
ct time.Time,
labelValues ...string,
) (Metric, error) {
if desc.err != nil {
return nil, desc.err
}
if err := validateLabelValues(labelValues, len(desc.variableLabels.names)); err != nil {
return nil, err
}
return &constSummary{
desc: desc,
count: count,
sum: sum,
quantiles: quantiles,
labelPairs: MakeLabelPairs(desc, labelValues),
createdTs: timestamppb.New(ct),
}, nil
}
// MustNewConstSummaryWithCreatedTimestamp is a version of NewConstSummaryWithCreatedTimestamp that panics where
// NewConstSummaryWithCreatedTimestamp would have returned an error.
func MustNewConstSummaryWithCreatedTimestamp(
desc *Desc,
count uint64,
sum float64,
quantiles map[float64]float64,
ct time.Time,
labelValues ...string,
) Metric {
m, err := NewConstSummaryWithCreatedTimestamp(desc, count, sum, quantiles, ct, labelValues...)
if err != nil {
panic(err)
}
return m
}

View File

@@ -79,7 +79,7 @@ func (m *MetricVec) DeleteLabelValues(lvs ...string) bool {
return false
}
return m.metricMap.deleteByHashWithLabelValues(h, lvs, m.curry)
return m.deleteByHashWithLabelValues(h, lvs, m.curry)
}
// Delete deletes the metric where the variable labels are the same as those
@@ -101,7 +101,7 @@ func (m *MetricVec) Delete(labels Labels) bool {
return false
}
return m.metricMap.deleteByHashWithLabels(h, labels, m.curry)
return m.deleteByHashWithLabels(h, labels, m.curry)
}
// DeletePartialMatch deletes all metrics where the variable labels contain all of those
@@ -114,7 +114,7 @@ func (m *MetricVec) DeletePartialMatch(labels Labels) int {
labels, closer := constrainLabels(m.desc, labels)
defer closer()
return m.metricMap.deleteByLabels(labels, m.curry)
return m.deleteByLabels(labels, m.curry)
}
// Without explicit forwarding of Describe, Collect, Reset, those methods won't
@@ -216,7 +216,7 @@ func (m *MetricVec) GetMetricWithLabelValues(lvs ...string) (Metric, error) {
return nil, err
}
return m.metricMap.getOrCreateMetricWithLabelValues(h, lvs, m.curry), nil
return m.getOrCreateMetricWithLabelValues(h, lvs, m.curry), nil
}
// GetMetricWith returns the Metric for the given Labels map (the label names
@@ -244,7 +244,7 @@ func (m *MetricVec) GetMetricWith(labels Labels) (Metric, error) {
return nil, err
}
return m.metricMap.getOrCreateMetricWithLabels(h, labels, m.curry), nil
return m.getOrCreateMetricWithLabels(h, labels, m.curry), nil
}
func (m *MetricVec) hashLabelValues(vals []string) (uint64, error) {
@@ -507,7 +507,7 @@ func (m *metricMap) getOrCreateMetricWithLabelValues(
return metric
}
// getOrCreateMetricWithLabelValues retrieves the metric by hash and label value
// getOrCreateMetricWithLabels retrieves the metric by hash and label value
// or creates it and returns the new one.
//
// This function holds the mutex.

View File

@@ -63,7 +63,7 @@ func WrapRegistererWith(labels Labels, reg Registerer) Registerer {
// metric names that are standardized across applications, as that would break
// horizontal monitoring, for example the metrics provided by the Go collector
// (see NewGoCollector) and the process collector (see NewProcessCollector). (In
// fact, those metrics are already prefixed with go_ or process_,
// fact, those metrics are already prefixed with "go_" or "process_",
// respectively.)
//
// Conflicts between Collectors registered through the original Registerer with
@@ -78,6 +78,40 @@ func WrapRegistererWithPrefix(prefix string, reg Registerer) Registerer {
}
}
// WrapCollectorWith returns a Collector wrapping the provided Collector. The
// wrapped Collector will add the provided Labels to all Metrics it collects (as
// ConstLabels). The Metrics collected by the unmodified Collector must not
// duplicate any of those labels.
//
// WrapCollectorWith can be useful to work with multiple instances of a third
// party library that does not expose enough flexibility on the lifecycle of its
// registered metrics.
// For example, let's say you have a foo.New(reg Registerer) constructor that
// registers metrics but never unregisters them, and you want to create multiple
// instances of foo.Foo with different labels.
// The way to achieve that, is to create a new Registry, pass it to foo.New,
// then use WrapCollectorWith to wrap that Registry with the desired labels and
// register that as a collector in your main Registry.
// Then you can un-register the wrapped collector effectively un-registering the
// metrics registered by foo.New.
func WrapCollectorWith(labels Labels, c Collector) Collector {
return &wrappingCollector{
wrappedCollector: c,
labels: labels,
}
}
// WrapCollectorWithPrefix returns a Collector wrapping the provided Collector. The
// wrapped Collector will add the provided prefix to the name of all Metrics it collects.
//
// See the documentation of WrapCollectorWith for more details on the use case.
func WrapCollectorWithPrefix(prefix string, c Collector) Collector {
return &wrappingCollector{
wrappedCollector: c,
prefix: prefix,
}
}
type wrappingRegisterer struct {
wrappedRegisterer Registerer
prefix string

View File

@@ -45,7 +45,7 @@ func ResponseFormat(h http.Header) Format {
mediatype, params, err := mime.ParseMediaType(ct)
if err != nil {
return fmtUnknown
return FmtUnknown
}
const textType = "text/plain"
@@ -53,36 +53,51 @@ func ResponseFormat(h http.Header) Format {
switch mediatype {
case ProtoType:
if p, ok := params["proto"]; ok && p != ProtoProtocol {
return fmtUnknown
return FmtUnknown
}
if e, ok := params["encoding"]; ok && e != "delimited" {
return fmtUnknown
return FmtUnknown
}
return fmtProtoDelim
return FmtProtoDelim
case textType:
if v, ok := params["version"]; ok && v != TextVersion {
return fmtUnknown
return FmtUnknown
}
return fmtText
return FmtText
}
return fmtUnknown
return FmtUnknown
}
// NewDecoder returns a new decoder based on the given input format.
// If the input format does not imply otherwise, a text format decoder is returned.
// NewDecoder returns a new decoder based on the given input format. Metric
// names are validated based on the provided Format -- if the format requires
// escaping, raditional Prometheues validity checking is used. Otherwise, names
// are checked for UTF-8 validity. Supported formats include delimited protobuf
// and Prometheus text format. For historical reasons, this decoder fallbacks
// to classic text decoding for any other format. This decoder does not fully
// support OpenMetrics although it may often succeed due to the similarities
// between the formats. This decoder may not support the latest features of
// Prometheus text format and is not intended for high-performance applications.
// See: https://github.com/prometheus/common/issues/812
func NewDecoder(r io.Reader, format Format) Decoder {
scheme := model.LegacyValidation
if format.ToEscapingScheme() == model.NoEscaping {
scheme = model.UTF8Validation
}
switch format.FormatType() {
case TypeProtoDelim:
return &protoDecoder{r: r}
return &protoDecoder{r: bufio.NewReader(r), s: scheme}
case TypeProtoText, TypeProtoCompact:
return &errDecoder{err: fmt.Errorf("format %s not supported for decoding", format)}
}
return &textDecoder{r: r}
return &textDecoder{r: r, s: scheme}
}
// protoDecoder implements the Decoder interface for protocol buffers.
type protoDecoder struct {
r io.Reader
r protodelim.Reader
s model.ValidationScheme
}
// Decode implements the Decoder interface.
@@ -90,10 +105,10 @@ func (d *protoDecoder) Decode(v *dto.MetricFamily) error {
opts := protodelim.UnmarshalOptions{
MaxSize: -1,
}
if err := opts.UnmarshalFrom(bufio.NewReader(d.r), v); err != nil {
if err := opts.UnmarshalFrom(d.r, v); err != nil {
return err
}
if !model.IsValidMetricName(model.LabelValue(v.GetName())) {
if !d.s.IsValidMetricName(v.GetName()) {
return fmt.Errorf("invalid metric name %q", v.GetName())
}
for _, m := range v.GetMetric() {
@@ -107,7 +122,7 @@ func (d *protoDecoder) Decode(v *dto.MetricFamily) error {
if !model.LabelValue(l.GetValue()).IsValid() {
return fmt.Errorf("invalid label value %q", l.GetValue())
}
if !model.LabelName(l.GetName()).IsValid() {
if !d.s.IsValidLabelName(l.GetName()) {
return fmt.Errorf("invalid label name %q", l.GetName())
}
}
@@ -115,10 +130,20 @@ func (d *protoDecoder) Decode(v *dto.MetricFamily) error {
return nil
}
// errDecoder is an error-state decoder that always returns the same error.
type errDecoder struct {
err error
}
func (d *errDecoder) Decode(*dto.MetricFamily) error {
return d.err
}
// textDecoder implements the Decoder interface for the text protocol.
type textDecoder struct {
r io.Reader
fams map[string]*dto.MetricFamily
s model.ValidationScheme
err error
}
@@ -126,7 +151,7 @@ type textDecoder struct {
func (d *textDecoder) Decode(v *dto.MetricFamily) error {
if d.err == nil {
// Read all metrics in one shot.
var p TextParser
p := NewTextParser(d.s)
d.fams, d.err = p.TextToMetricFamilies(d.r)
// If we don't get an error, store io.EOF for the end.
if d.err == nil {
@@ -195,7 +220,7 @@ func extractSamples(f *dto.MetricFamily, o *DecodeOptions) (model.Vector, error)
return extractSummary(o, f), nil
case dto.MetricType_UNTYPED:
return extractUntyped(o, f), nil
case dto.MetricType_HISTOGRAM:
case dto.MetricType_HISTOGRAM, dto.MetricType_GAUGE_HISTOGRAM:
return extractHistogram(o, f), nil
}
return nil, fmt.Errorf("expfmt.extractSamples: unknown metric family type %v", f.GetType())
@@ -378,9 +403,13 @@ func extractHistogram(o *DecodeOptions, f *dto.MetricFamily) model.Vector {
infSeen = true
}
v := q.GetCumulativeCountFloat()
if v <= 0 {
v = float64(q.GetCumulativeCount())
}
samples = append(samples, &model.Sample{
Metric: model.Metric(lset),
Value: model.SampleValue(q.GetCumulativeCount()),
Value: model.SampleValue(v),
Timestamp: timestamp,
})
}
@@ -403,9 +432,13 @@ func extractHistogram(o *DecodeOptions, f *dto.MetricFamily) model.Vector {
}
lset[model.MetricNameLabel] = model.LabelValue(f.GetName() + "_count")
v := m.Histogram.GetSampleCountFloat()
if v <= 0 {
v = float64(m.Histogram.GetSampleCount())
}
count := &model.Sample{
Metric: model.Metric(lset),
Value: model.SampleValue(m.Histogram.GetSampleCount()),
Value: model.SampleValue(v),
Timestamp: timestamp,
}
samples = append(samples, count)

View File

@@ -18,13 +18,12 @@ import (
"io"
"net/http"
"github.com/munnerz/goautoneg"
dto "github.com/prometheus/client_model/go"
"google.golang.org/protobuf/encoding/protodelim"
"google.golang.org/protobuf/encoding/prototext"
"github.com/prometheus/common/internal/bitbucket.org/ww/goautoneg"
"github.com/prometheus/common/model"
dto "github.com/prometheus/client_model/go"
)
// Encoder types encode metric families into an underlying wire protocol.
@@ -60,14 +59,14 @@ func (ec encoderCloser) Close() error {
// appropriate accepted type is found, FmtText is returned (which is the
// Prometheus text format). This function will never negotiate FmtOpenMetrics,
// as the support is still experimental. To include the option to negotiate
// FmtOpenMetrics, use NegotiateOpenMetrics.
// FmtOpenMetrics, use NegotiateIncludingOpenMetrics.
func Negotiate(h http.Header) Format {
escapingScheme := Format(fmt.Sprintf("; escaping=%s", Format(model.NameEscapingScheme.String())))
for _, ac := range goautoneg.ParseAccept(h.Get(hdrAccept)) {
if escapeParam := ac.Params[model.EscapingKey]; escapeParam != "" {
switch Format(escapeParam) {
case model.AllowUTF8, model.EscapeUnderscores, model.EscapeDots, model.EscapeValues:
escapingScheme = Format(fmt.Sprintf("; escaping=%s", escapeParam))
escapingScheme = Format("; escaping=" + escapeParam)
default:
// If the escaping parameter is unknown, ignore it.
}
@@ -76,18 +75,18 @@ func Negotiate(h http.Header) Format {
if ac.Type+"/"+ac.SubType == ProtoType && ac.Params["proto"] == ProtoProtocol {
switch ac.Params["encoding"] {
case "delimited":
return fmtProtoDelim + escapingScheme
return FmtProtoDelim + escapingScheme
case "text":
return fmtProtoText + escapingScheme
return FmtProtoText + escapingScheme
case "compact-text":
return fmtProtoCompact + escapingScheme
return FmtProtoCompact + escapingScheme
}
}
if ac.Type == "text" && ac.SubType == "plain" && (ver == TextVersion || ver == "") {
return fmtText + escapingScheme
return FmtText + escapingScheme
}
}
return fmtText + escapingScheme
return FmtText + escapingScheme
}
// NegotiateIncludingOpenMetrics works like Negotiate but includes
@@ -100,7 +99,7 @@ func NegotiateIncludingOpenMetrics(h http.Header) Format {
if escapeParam := ac.Params[model.EscapingKey]; escapeParam != "" {
switch Format(escapeParam) {
case model.AllowUTF8, model.EscapeUnderscores, model.EscapeDots, model.EscapeValues:
escapingScheme = Format(fmt.Sprintf("; escaping=%s", escapeParam))
escapingScheme = Format("; escaping=" + escapeParam)
default:
// If the escaping parameter is unknown, ignore it.
}
@@ -109,26 +108,26 @@ func NegotiateIncludingOpenMetrics(h http.Header) Format {
if ac.Type+"/"+ac.SubType == ProtoType && ac.Params["proto"] == ProtoProtocol {
switch ac.Params["encoding"] {
case "delimited":
return fmtProtoDelim + escapingScheme
return FmtProtoDelim + escapingScheme
case "text":
return fmtProtoText + escapingScheme
return FmtProtoText + escapingScheme
case "compact-text":
return fmtProtoCompact + escapingScheme
return FmtProtoCompact + escapingScheme
}
}
if ac.Type == "text" && ac.SubType == "plain" && (ver == TextVersion || ver == "") {
return fmtText + escapingScheme
return FmtText + escapingScheme
}
if ac.Type+"/"+ac.SubType == OpenMetricsType && (ver == OpenMetricsVersion_0_0_1 || ver == OpenMetricsVersion_1_0_0 || ver == "") {
switch ver {
case OpenMetricsVersion_1_0_0:
return fmtOpenMetrics_1_0_0 + escapingScheme
return FmtOpenMetrics_1_0_0 + escapingScheme
default:
return fmtOpenMetrics_0_0_1 + escapingScheme
return FmtOpenMetrics_0_0_1 + escapingScheme
}
}
}
return fmtText + escapingScheme
return FmtText + escapingScheme
}
// NewEncoder returns a new encoder based on content type negotiation. All
@@ -152,7 +151,7 @@ func NewEncoder(w io.Writer, format Format, options ...EncoderOption) Encoder {
case TypeProtoDelim:
return encoderCloser{
encode: func(v *dto.MetricFamily) error {
_, err := protodelim.MarshalTo(w, v)
_, err := protodelim.MarshalTo(w, model.EscapeMetricFamily(v, escapingScheme))
return err
},
close: func() error { return nil },

View File

@@ -15,6 +15,7 @@
package expfmt
import (
"errors"
"strings"
"github.com/prometheus/common/model"
@@ -34,21 +35,33 @@ const (
TextVersion = "0.0.4"
ProtoType = `application/vnd.google.protobuf`
ProtoProtocol = `io.prometheus.client.MetricFamily`
protoFmt = ProtoType + "; proto=" + ProtoProtocol + ";"
// Deprecated: Use expfmt.NewFormat(expfmt.TypeProtoCompact) instead.
ProtoFmt = ProtoType + "; proto=" + ProtoProtocol + ";"
OpenMetricsType = `application/openmetrics-text`
//nolint:revive // Allow for underscores.
OpenMetricsVersion_0_0_1 = "0.0.1"
//nolint:revive // Allow for underscores.
OpenMetricsVersion_1_0_0 = "1.0.0"
// The Content-Type values for the different wire protocols. Note that these
// values are now unexported. If code was relying on comparisons to these
// constants, instead use FormatType().
fmtUnknown Format = `<unknown>`
fmtText Format = `text/plain; version=` + TextVersion + `; charset=utf-8`
fmtProtoDelim Format = protoFmt + ` encoding=delimited`
fmtProtoText Format = protoFmt + ` encoding=text`
fmtProtoCompact Format = protoFmt + ` encoding=compact-text`
fmtOpenMetrics_1_0_0 Format = OpenMetricsType + `; version=` + OpenMetricsVersion_1_0_0 + `; charset=utf-8`
fmtOpenMetrics_0_0_1 Format = OpenMetricsType + `; version=` + OpenMetricsVersion_0_0_1 + `; charset=utf-8`
// The Content-Type values for the different wire protocols. Do not do direct
// comparisons to these constants, instead use the comparison functions.
//
// Deprecated: Use expfmt.NewFormat(expfmt.TypeUnknown) instead.
FmtUnknown Format = `<unknown>`
// Deprecated: Use expfmt.NewFormat(expfmt.TypeTextPlain) instead.
FmtText Format = `text/plain; version=` + TextVersion + `; charset=utf-8`
// Deprecated: Use expfmt.NewFormat(expfmt.TypeProtoDelim) instead.
FmtProtoDelim Format = ProtoFmt + ` encoding=delimited`
// Deprecated: Use expfmt.NewFormat(expfmt.TypeProtoText) instead.
FmtProtoText Format = ProtoFmt + ` encoding=text`
// Deprecated: Use expfmt.NewFormat(expfmt.TypeProtoCompact) instead.
FmtProtoCompact Format = ProtoFmt + ` encoding=compact-text`
// Deprecated: Use expfmt.NewFormat(expfmt.TypeOpenMetrics) instead.
//nolint:revive // Allow for underscores.
FmtOpenMetrics_1_0_0 Format = OpenMetricsType + `; version=` + OpenMetricsVersion_1_0_0 + `; charset=utf-8`
// Deprecated: Use expfmt.NewFormat(expfmt.TypeOpenMetrics) instead.
//nolint:revive // Allow for underscores.
FmtOpenMetrics_0_0_1 Format = OpenMetricsType + `; version=` + OpenMetricsVersion_0_0_1 + `; charset=utf-8`
)
const (
@@ -63,7 +76,7 @@ const (
type FormatType int
const (
TypeUnknown = iota
TypeUnknown FormatType = iota
TypeProtoCompact
TypeProtoDelim
TypeProtoText
@@ -73,31 +86,63 @@ const (
// NewFormat generates a new Format from the type provided. Mostly used for
// tests, most Formats should be generated as part of content negotiation in
// encode.go.
// encode.go. If a type has more than one version, the latest version will be
// returned.
func NewFormat(t FormatType) Format {
switch t {
case TypeProtoCompact:
return fmtProtoCompact
return FmtProtoCompact
case TypeProtoDelim:
return fmtProtoDelim
return FmtProtoDelim
case TypeProtoText:
return fmtProtoText
return FmtProtoText
case TypeTextPlain:
return fmtText
return FmtText
case TypeOpenMetrics:
return fmtOpenMetrics_1_0_0
return FmtOpenMetrics_1_0_0
default:
return fmtUnknown
return FmtUnknown
}
}
// NewOpenMetricsFormat generates a new OpenMetrics format matching the
// specified version number.
func NewOpenMetricsFormat(version string) (Format, error) {
if version == OpenMetricsVersion_0_0_1 {
return FmtOpenMetrics_0_0_1, nil
}
if version == OpenMetricsVersion_1_0_0 {
return FmtOpenMetrics_1_0_0, nil
}
return FmtUnknown, errors.New("unknown open metrics version string")
}
// WithEscapingScheme returns a copy of Format with the specified escaping
// scheme appended to the end. If an escaping scheme already exists it is
// removed.
func (f Format) WithEscapingScheme(s model.EscapingScheme) Format {
var terms []string
for _, p := range strings.Split(string(f), ";") {
toks := strings.Split(p, "=")
if len(toks) != 2 {
trimmed := strings.TrimSpace(p)
if len(trimmed) > 0 {
terms = append(terms, trimmed)
}
continue
}
key := strings.TrimSpace(toks[0])
if key != model.EscapingKey {
terms = append(terms, strings.TrimSpace(p))
}
}
terms = append(terms, model.EscapingKey+"="+s.String())
return Format(strings.Join(terms, "; "))
}
// FormatType deduces an overall FormatType for the given format.
func (f Format) FormatType() FormatType {
toks := strings.Split(string(f), ";")
if len(toks) < 2 {
return TypeUnknown
}
params := make(map[string]string)
for i, t := range toks {
if i == 0 {
@@ -148,8 +193,8 @@ func (f Format) FormatType() FormatType {
// Format contains a escaping=allow-utf-8 term, it will select NoEscaping. If a valid
// "escaping" term exists, that will be used. Otherwise, the global default will
// be returned.
func (format Format) ToEscapingScheme() model.EscapingScheme {
for _, p := range strings.Split(string(format), ";") {
func (f Format) ToEscapingScheme() model.EscapingScheme {
for _, p := range strings.Split(string(f), ";") {
toks := strings.Split(p, "=")
if len(toks) != 2 {
continue

View File

@@ -13,11 +13,14 @@
// Build only when actually fuzzing
//go:build gofuzz
// +build gofuzz
package expfmt
import "bytes"
import (
"bytes"
"github.com/prometheus/common/model"
)
// Fuzz text metric parser with with github.com/dvyukov/go-fuzz:
//
@@ -26,9 +29,8 @@ import "bytes"
//
// Further input samples should go in the folder fuzz/corpus.
func Fuzz(in []byte) int {
parser := TextParser{}
parser := NewTextParser(model.UTF8Validation)
_, err := parser.TextToMetricFamilies(bytes.NewReader(in))
if err != nil {
return 0
}

View File

@@ -22,11 +22,10 @@ import (
"strconv"
"strings"
dto "github.com/prometheus/client_model/go"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/prometheus/common/model"
dto "github.com/prometheus/client_model/go"
)
type encoderOption struct {
@@ -38,7 +37,7 @@ type EncoderOption func(*encoderOption)
// WithCreatedLines is an EncoderOption that configures the OpenMetrics encoder
// to include _created lines (See
// https://github.com/OpenObservability/OpenMetrics/blob/main/specification/OpenMetrics.md#counter-1).
// https://github.com/prometheus/OpenMetrics/blob/v1.0.0/specification/OpenMetrics.md#counter-1).
// Created timestamps can improve the accuracy of series reset detection, but
// come with a bandwidth cost.
//
@@ -102,7 +101,7 @@ func WithUnit() EncoderOption {
//
// - According to the OM specs, the `# UNIT` line is optional, but if populated,
// the unit has to be present in the metric name as its suffix:
// (see https://github.com/OpenObservability/OpenMetrics/blob/main/specification/OpenMetrics.md#unit).
// (see https://github.com/prometheus/OpenMetrics/blob/v1.0.0/specification/OpenMetrics.md#unit).
// However, in order to accommodate any potential scenario where such a change in the
// metric name is not desirable, the users are here given the choice of either explicitly
// opt in, in case they wish for the unit to be included in the output AND in the metric name
@@ -152,8 +151,8 @@ func MetricFamilyToOpenMetrics(out io.Writer, in *dto.MetricFamily, options ...E
if metricType == dto.MetricType_COUNTER && strings.HasSuffix(compliantName, "_total") {
compliantName = name[:len(name)-6]
}
if toOM.withUnit && in.Unit != nil && !strings.HasSuffix(compliantName, fmt.Sprintf("_%s", *in.Unit)) {
compliantName = compliantName + fmt.Sprintf("_%s", *in.Unit)
if toOM.withUnit && in.Unit != nil && !strings.HasSuffix(compliantName, "_"+*in.Unit) {
compliantName = compliantName + "_" + *in.Unit
}
// Comments, first HELP, then TYPE.
@@ -161,38 +160,38 @@ func MetricFamilyToOpenMetrics(out io.Writer, in *dto.MetricFamily, options ...E
n, err = w.WriteString("# HELP ")
written += n
if err != nil {
return
return written, err
}
n, err = writeName(w, compliantName)
written += n
if err != nil {
return
return written, err
}
err = w.WriteByte(' ')
written++
if err != nil {
return
return written, err
}
n, err = writeEscapedString(w, *in.Help, true)
written += n
if err != nil {
return
return written, err
}
err = w.WriteByte('\n')
written++
if err != nil {
return
return written, err
}
}
n, err = w.WriteString("# TYPE ")
written += n
if err != nil {
return
return written, err
}
n, err = writeName(w, compliantName)
written += n
if err != nil {
return
return written, err
}
switch metricType {
case dto.MetricType_COUNTER:
@@ -209,51 +208,53 @@ func MetricFamilyToOpenMetrics(out io.Writer, in *dto.MetricFamily, options ...E
n, err = w.WriteString(" unknown\n")
case dto.MetricType_HISTOGRAM:
n, err = w.WriteString(" histogram\n")
case dto.MetricType_GAUGE_HISTOGRAM:
n, err = w.WriteString(" gaugehistogram\n")
default:
return written, fmt.Errorf("unknown metric type %s", metricType.String())
}
written += n
if err != nil {
return
return written, err
}
if toOM.withUnit && in.Unit != nil {
n, err = w.WriteString("# UNIT ")
written += n
if err != nil {
return
return written, err
}
n, err = writeName(w, compliantName)
written += n
if err != nil {
return
return written, err
}
err = w.WriteByte(' ')
written++
if err != nil {
return
return written, err
}
n, err = writeEscapedString(w, *in.Unit, true)
written += n
if err != nil {
return
return written, err
}
err = w.WriteByte('\n')
written++
if err != nil {
return
return written, err
}
}
var createdTsBytesWritten int
// Finally the samples, one line for each.
if metricType == dto.MetricType_COUNTER && strings.HasSuffix(name, "_total") {
compliantName += "_total"
}
for _, metric := range in.Metric {
switch metricType {
case dto.MetricType_COUNTER:
if strings.HasSuffix(name, "_total") {
compliantName = compliantName + "_total"
}
if metric.Counter == nil {
return written, fmt.Errorf(
"expected counter in metric %s %s", compliantName, metric,
@@ -305,7 +306,7 @@ func MetricFamilyToOpenMetrics(out io.Writer, in *dto.MetricFamily, options ...E
)
written += n
if err != nil {
return
return written, err
}
}
n, err = writeOpenMetricsSample(
@@ -315,7 +316,7 @@ func MetricFamilyToOpenMetrics(out io.Writer, in *dto.MetricFamily, options ...E
)
written += n
if err != nil {
return
return written, err
}
n, err = writeOpenMetricsSample(
w, compliantName, "_count", metric, "", 0,
@@ -326,7 +327,7 @@ func MetricFamilyToOpenMetrics(out io.Writer, in *dto.MetricFamily, options ...E
createdTsBytesWritten, err = writeOpenMetricsCreated(w, compliantName, "", metric, "", 0, metric.Summary.GetCreatedTimestamp())
n += createdTsBytesWritten
}
case dto.MetricType_HISTOGRAM:
case dto.MetricType_HISTOGRAM, dto.MetricType_GAUGE_HISTOGRAM:
if metric.Histogram == nil {
return written, fmt.Errorf(
"expected histogram in metric %s %s", compliantName, metric,
@@ -334,6 +335,12 @@ func MetricFamilyToOpenMetrics(out io.Writer, in *dto.MetricFamily, options ...E
}
infSeen := false
for _, b := range metric.Histogram.Bucket {
if b.GetCumulativeCountFloat() > 0 {
return written, fmt.Errorf(
"OpenMetrics v1.0 does not support float histogram %s %s",
compliantName, metric,
)
}
n, err = writeOpenMetricsSample(
w, compliantName, "_bucket", metric,
model.BucketLabel, b.GetUpperBound(),
@@ -342,7 +349,7 @@ func MetricFamilyToOpenMetrics(out io.Writer, in *dto.MetricFamily, options ...E
)
written += n
if err != nil {
return
return written, err
}
if math.IsInf(b.GetUpperBound(), +1) {
infSeen = true
@@ -355,9 +362,12 @@ func MetricFamilyToOpenMetrics(out io.Writer, in *dto.MetricFamily, options ...E
0, metric.Histogram.GetSampleCount(), true,
nil,
)
// We do not check for a float sample count here
// because we will check for it below (and error
// out if needed).
written += n
if err != nil {
return
return written, err
}
}
n, err = writeOpenMetricsSample(
@@ -367,7 +377,13 @@ func MetricFamilyToOpenMetrics(out io.Writer, in *dto.MetricFamily, options ...E
)
written += n
if err != nil {
return
return written, err
}
if metric.Histogram.GetSampleCountFloat() > 0 {
return written, fmt.Errorf(
"OpenMetrics v1.0 does not support float histogram %s %s",
compliantName, metric,
)
}
n, err = writeOpenMetricsSample(
w, compliantName, "_count", metric, "", 0,
@@ -385,10 +401,10 @@ func MetricFamilyToOpenMetrics(out io.Writer, in *dto.MetricFamily, options ...E
}
written += n
if err != nil {
return
return written, err
}
}
return
return written, err
}
// FinalizeOpenMetrics writes the final `# EOF\n` line required by OpenMetrics.
@@ -477,7 +493,7 @@ func writeOpenMetricsNameAndLabelPairs(
if name != "" {
// If the name does not pass the legacy validity check, we must put the
// metric name inside the braces, quoted.
if !model.IsValidLegacyMetricName(model.LabelValue(name)) {
if !model.LegacyValidation.IsValidMetricName(name) {
metricInsideBraces = true
err := w.WriteByte(separator)
written++
@@ -641,11 +657,11 @@ func writeExemplar(w enhancedWriter, e *dto.Exemplar) (int, error) {
if err != nil {
return written, err
}
err = (*e).Timestamp.CheckValid()
err = e.Timestamp.CheckValid()
if err != nil {
return written, err
}
ts := (*e).Timestamp.AsTime()
ts := e.Timestamp.AsTime()
// TODO(beorn7): Format this directly from components of ts to
// avoid overflow/underflow and precision issues of the float
// conversion.

View File

@@ -22,9 +22,9 @@ import (
"strings"
"sync"
"github.com/prometheus/common/model"
dto "github.com/prometheus/client_model/go"
"github.com/prometheus/common/model"
)
// enhancedWriter has all the enhanced write functions needed here. bufio.Writer
@@ -108,38 +108,38 @@ func MetricFamilyToText(out io.Writer, in *dto.MetricFamily) (written int, err e
n, err = w.WriteString("# HELP ")
written += n
if err != nil {
return
return written, err
}
n, err = writeName(w, name)
written += n
if err != nil {
return
return written, err
}
err = w.WriteByte(' ')
written++
if err != nil {
return
return written, err
}
n, err = writeEscapedString(w, *in.Help, false)
written += n
if err != nil {
return
return written, err
}
err = w.WriteByte('\n')
written++
if err != nil {
return
return written, err
}
}
n, err = w.WriteString("# TYPE ")
written += n
if err != nil {
return
return written, err
}
n, err = writeName(w, name)
written += n
if err != nil {
return
return written, err
}
metricType := in.GetType()
switch metricType {
@@ -151,14 +151,17 @@ func MetricFamilyToText(out io.Writer, in *dto.MetricFamily) (written int, err e
n, err = w.WriteString(" summary\n")
case dto.MetricType_UNTYPED:
n, err = w.WriteString(" untyped\n")
case dto.MetricType_HISTOGRAM:
case dto.MetricType_HISTOGRAM, dto.MetricType_GAUGE_HISTOGRAM:
// The classic Prometheus text format has no notion of a gauge
// histogram. We render a gauge histogram in the same way as a
// regular histogram.
n, err = w.WriteString(" histogram\n")
default:
return written, fmt.Errorf("unknown metric type %s", metricType.String())
}
written += n
if err != nil {
return
return written, err
}
// Finally the samples, one line for each.
@@ -208,7 +211,7 @@ func MetricFamilyToText(out io.Writer, in *dto.MetricFamily) (written int, err e
)
written += n
if err != nil {
return
return written, err
}
}
n, err = writeSample(
@@ -217,13 +220,13 @@ func MetricFamilyToText(out io.Writer, in *dto.MetricFamily) (written int, err e
)
written += n
if err != nil {
return
return written, err
}
n, err = writeSample(
w, name, "_count", metric, "", 0,
float64(metric.Summary.GetSampleCount()),
)
case dto.MetricType_HISTOGRAM:
case dto.MetricType_HISTOGRAM, dto.MetricType_GAUGE_HISTOGRAM:
if metric.Histogram == nil {
return written, fmt.Errorf(
"expected histogram in metric %s %s", name, metric,
@@ -231,28 +234,36 @@ func MetricFamilyToText(out io.Writer, in *dto.MetricFamily) (written int, err e
}
infSeen := false
for _, b := range metric.Histogram.Bucket {
v := b.GetCumulativeCountFloat()
if v == 0 {
v = float64(b.GetCumulativeCount())
}
n, err = writeSample(
w, name, "_bucket", metric,
model.BucketLabel, b.GetUpperBound(),
float64(b.GetCumulativeCount()),
v,
)
written += n
if err != nil {
return
return written, err
}
if math.IsInf(b.GetUpperBound(), +1) {
infSeen = true
}
}
if !infSeen {
v := metric.Histogram.GetSampleCountFloat()
if v == 0 {
v = float64(metric.Histogram.GetSampleCount())
}
n, err = writeSample(
w, name, "_bucket", metric,
model.BucketLabel, math.Inf(+1),
float64(metric.Histogram.GetSampleCount()),
v,
)
written += n
if err != nil {
return
return written, err
}
}
n, err = writeSample(
@@ -261,12 +272,13 @@ func MetricFamilyToText(out io.Writer, in *dto.MetricFamily) (written int, err e
)
written += n
if err != nil {
return
return written, err
}
n, err = writeSample(
w, name, "_count", metric, "", 0,
float64(metric.Histogram.GetSampleCount()),
)
v := metric.Histogram.GetSampleCountFloat()
if v == 0 {
v = float64(metric.Histogram.GetSampleCount())
}
n, err = writeSample(w, name, "_count", metric, "", 0, v)
default:
return written, fmt.Errorf(
"unexpected type in metric %s %s", name, metric,
@@ -274,10 +286,10 @@ func MetricFamilyToText(out io.Writer, in *dto.MetricFamily) (written int, err e
}
written += n
if err != nil {
return
return written, err
}
}
return
return written, err
}
// writeSample writes a single sample in text format to w, given the metric
@@ -354,7 +366,7 @@ func writeNameAndLabelPairs(
if name != "" {
// If the name does not pass the legacy validity check, we must put the
// metric name inside the braces.
if !model.IsValidLegacyMetricName(model.LabelValue(name)) {
if !model.LegacyValidation.IsValidMetricName(name) {
metricInsideBraces = true
err := w.WriteByte(separator)
written++
@@ -498,7 +510,7 @@ func writeInt(w enhancedWriter, i int64) (int, error) {
// writeName writes a string as-is if it complies with the legacy naming
// scheme, or escapes it in double quotes if not.
func writeName(w enhancedWriter, name string) (int, error) {
if model.IsValidLegacyMetricName(model.LabelValue(name)) {
if model.LegacyValidation.IsValidMetricName(name) {
return w.WriteString(name)
}
var written int

View File

@@ -22,9 +22,9 @@ import (
"math"
"strconv"
"strings"
"unicode/utf8"
dto "github.com/prometheus/client_model/go"
"google.golang.org/protobuf/proto"
"github.com/prometheus/common/model"
@@ -48,8 +48,10 @@ func (e ParseError) Error() string {
return fmt.Sprintf("text format parsing error in line %d: %s", e.Line, e.Msg)
}
// TextParser is used to parse the simple and flat text-based exchange format. Its
// zero value is ready to use.
// TextParser is used to parse the simple and flat text-based exchange format.
//
// TextParser instances must be created with NewTextParser, the zero value of
// TextParser is invalid.
type TextParser struct {
metricFamiliesByName map[string]*dto.MetricFamily
buf *bufio.Reader // Where the parsed input is read through.
@@ -60,6 +62,7 @@ type TextParser struct {
currentMF *dto.MetricFamily
currentMetric *dto.Metric
currentLabelPair *dto.LabelPair
currentLabelPairs []*dto.LabelPair // Temporarily stores label pairs while parsing a metric line.
// The remaining member variables are only used for summaries/histograms.
currentLabels map[string]string // All labels including '__name__' but excluding 'quantile'/'le'
@@ -74,6 +77,17 @@ type TextParser struct {
// count and sum of that summary/histogram.
currentIsSummaryCount, currentIsSummarySum bool
currentIsHistogramCount, currentIsHistogramSum bool
// These indicate if the metric name from the current line being parsed is inside
// braces and if that metric name was found respectively.
currentMetricIsInsideBraces, currentMetricInsideBracesIsPresent bool
// scheme sets the desired ValidationScheme for names. Defaults to the invalid
// UnsetValidation.
scheme model.ValidationScheme
}
// NewTextParser returns a new TextParser with the provided nameValidationScheme.
func NewTextParser(nameValidationScheme model.ValidationScheme) TextParser {
return TextParser{scheme: nameValidationScheme}
}
// TextToMetricFamilies reads 'in' as the simple and flat text-based exchange
@@ -117,11 +131,47 @@ func (p *TextParser) TextToMetricFamilies(in io.Reader) (map[string]*dto.MetricF
if p.err != nil && errors.Is(p.err, io.EOF) {
p.parseError("unexpected end of input stream")
}
for _, histogramMetric := range p.histograms {
normalizeHistogram(histogramMetric.GetHistogram())
}
return p.metricFamiliesByName, p.err
}
// normalizeHistogram makes sure that all the buckets and the count in each
// histogram is either completely float or completely integer.
func normalizeHistogram(histogram *dto.Histogram) {
if histogram == nil {
return
}
anyFloats := false
if histogram.GetSampleCountFloat() != 0 {
anyFloats = true
} else {
for _, b := range histogram.GetBucket() {
if b.GetCumulativeCountFloat() != 0 {
anyFloats = true
break
}
}
}
if !anyFloats {
return
}
if histogram.GetSampleCountFloat() == 0 {
histogram.SampleCountFloat = proto.Float64(float64(histogram.GetSampleCount()))
histogram.SampleCount = nil
}
for _, b := range histogram.GetBucket() {
if b.GetCumulativeCountFloat() == 0 {
b.CumulativeCountFloat = proto.Float64(float64(b.GetCumulativeCount()))
b.CumulativeCount = nil
}
}
}
func (p *TextParser) reset(in io.Reader) {
p.metricFamiliesByName = map[string]*dto.MetricFamily{}
p.currentLabelPairs = nil
if p.buf == nil {
p.buf = bufio.NewReader(in)
} else {
@@ -137,12 +187,15 @@ func (p *TextParser) reset(in io.Reader) {
}
p.currentQuantile = math.NaN()
p.currentBucket = math.NaN()
p.currentMF = nil
}
// startOfLine represents the state where the next byte read from p.buf is the
// start of a line (or whitespace leading up to it).
func (p *TextParser) startOfLine() stateFn {
p.lineCount++
p.currentMetricIsInsideBraces = false
p.currentMetricInsideBracesIsPresent = false
if p.skipBlankTab(); p.err != nil {
// This is the only place that we expect to see io.EOF,
// which is not an error but the signal that we are done.
@@ -158,6 +211,9 @@ func (p *TextParser) startOfLine() stateFn {
return p.startComment
case '\n':
return p.startOfLine // Empty line, start the next one.
case '{':
p.currentMetricIsInsideBraces = true
return p.readingLabels
}
return p.readingMetricName
}
@@ -206,6 +262,9 @@ func (p *TextParser) startComment() stateFn {
return nil
}
p.setOrCreateCurrentMF()
if p.err != nil {
return nil
}
if p.skipBlankTab(); p.err != nil {
return nil // Unexpected end of input.
}
@@ -234,6 +293,9 @@ func (p *TextParser) readingMetricName() stateFn {
return nil
}
p.setOrCreateCurrentMF()
if p.err != nil {
return nil
}
// Now is the time to fix the type if it hasn't happened yet.
if p.currentMF.Type == nil {
p.currentMF.Type = dto.MetricType_UNTYPED.Enum()
@@ -256,7 +318,9 @@ func (p *TextParser) readingLabels() stateFn {
// Summaries/histograms are special. We have to reset the
// currentLabels map, currentQuantile and currentBucket before starting to
// read labels.
if p.currentMF.GetType() == dto.MetricType_SUMMARY || p.currentMF.GetType() == dto.MetricType_HISTOGRAM {
if p.currentMF.GetType() == dto.MetricType_SUMMARY ||
p.currentMF.GetType() == dto.MetricType_HISTOGRAM ||
p.currentMF.GetType() == dto.MetricType_GAUGE_HISTOGRAM {
p.currentLabels = map[string]string{}
p.currentLabels[string(model.MetricNameLabel)] = p.currentMF.GetName()
p.currentQuantile = math.NaN()
@@ -275,6 +339,8 @@ func (p *TextParser) startLabelName() stateFn {
return nil // Unexpected end of input.
}
if p.currentByte == '}' {
p.currentMetric.Label = append(p.currentMetric.Label, p.currentLabelPairs...)
p.currentLabelPairs = nil
if p.skipBlankTab(); p.err != nil {
return nil // Unexpected end of input.
}
@@ -287,34 +353,81 @@ func (p *TextParser) startLabelName() stateFn {
p.parseError(fmt.Sprintf("invalid label name for metric %q", p.currentMF.GetName()))
return nil
}
p.currentLabelPair = &dto.LabelPair{Name: proto.String(p.currentToken.String())}
if p.currentLabelPair.GetName() == string(model.MetricNameLabel) {
p.parseError(fmt.Sprintf("label name %q is reserved", model.MetricNameLabel))
return nil
}
// Special summary/histogram treatment. Don't add 'quantile' and 'le'
// labels to 'real' labels.
if !(p.currentMF.GetType() == dto.MetricType_SUMMARY && p.currentLabelPair.GetName() == model.QuantileLabel) &&
!(p.currentMF.GetType() == dto.MetricType_HISTOGRAM && p.currentLabelPair.GetName() == model.BucketLabel) {
p.currentMetric.Label = append(p.currentMetric.Label, p.currentLabelPair)
}
if p.skipBlankTabIfCurrentBlankTab(); p.err != nil {
return nil // Unexpected end of input.
}
if p.currentByte != '=' {
p.parseError(fmt.Sprintf("expected '=' after label name, found %q", p.currentByte))
if p.currentMetricIsInsideBraces {
if p.currentMetricInsideBracesIsPresent {
p.parseError(fmt.Sprintf("multiple metric names for metric %q", p.currentMF.GetName()))
return nil
}
switch p.currentByte {
case ',':
p.setOrCreateCurrentMF()
if p.err != nil {
return nil
}
if p.currentMF.Type == nil {
p.currentMF.Type = dto.MetricType_UNTYPED.Enum()
}
p.currentMetric = &dto.Metric{}
p.currentMetricInsideBracesIsPresent = true
return p.startLabelName
case '}':
p.setOrCreateCurrentMF()
if p.err != nil {
p.currentLabelPairs = nil
return nil
}
if p.currentMF.Type == nil {
p.currentMF.Type = dto.MetricType_UNTYPED.Enum()
}
p.currentMetric = &dto.Metric{}
p.currentMetric.Label = append(p.currentMetric.Label, p.currentLabelPairs...)
p.currentLabelPairs = nil
if p.skipBlankTab(); p.err != nil {
return nil // Unexpected end of input.
}
return p.readingValue
default:
p.parseError(fmt.Sprintf("unexpected end of metric name %q", p.currentByte))
return nil
}
}
p.parseError(fmt.Sprintf("expected '=' after label name, found %q", p.currentByte))
p.currentLabelPairs = nil
return nil
}
p.currentLabelPair = &dto.LabelPair{Name: proto.String(p.currentToken.String())}
if p.currentLabelPair.GetName() == string(model.MetricNameLabel) {
p.parseError(fmt.Sprintf("label name %q is reserved", model.MetricNameLabel))
p.currentLabelPairs = nil
return nil
}
if !p.scheme.IsValidLabelName(p.currentLabelPair.GetName()) {
p.parseError(fmt.Sprintf("invalid label name %q", p.currentLabelPair.GetName()))
p.currentLabelPairs = nil
return nil
}
// Special summary/histogram treatment. Don't add 'quantile' and 'le'
// labels to 'real' labels.
if (p.currentMF.GetType() != dto.MetricType_SUMMARY || p.currentLabelPair.GetName() != model.QuantileLabel) &&
((p.currentMF.GetType() != dto.MetricType_HISTOGRAM &&
p.currentMF.GetType() != dto.MetricType_GAUGE_HISTOGRAM) ||
p.currentLabelPair.GetName() != model.BucketLabel) {
p.currentLabelPairs = append(p.currentLabelPairs, p.currentLabelPair)
}
// Check for duplicate label names.
labels := make(map[string]struct{})
for _, l := range p.currentMetric.Label {
for _, l := range p.currentLabelPairs {
lName := l.GetName()
if _, exists := labels[lName]; !exists {
labels[lName] = struct{}{}
} else {
if _, exists := labels[lName]; exists {
p.parseError(fmt.Sprintf("duplicate label names for metric %q", p.currentMF.GetName()))
p.currentLabelPairs = nil
return nil
}
labels[lName] = struct{}{}
}
return p.startLabelValue
}
@@ -345,6 +458,7 @@ func (p *TextParser) startLabelValue() stateFn {
if p.currentQuantile, p.err = parseFloat(p.currentLabelPair.GetValue()); p.err != nil {
// Create a more helpful error message.
p.parseError(fmt.Sprintf("expected float as value for 'quantile' label, got %q", p.currentLabelPair.GetValue()))
p.currentLabelPairs = nil
return nil
}
} else {
@@ -352,7 +466,7 @@ func (p *TextParser) startLabelValue() stateFn {
}
}
// Similar special treatment of histograms.
if p.currentMF.GetType() == dto.MetricType_HISTOGRAM {
if p.currentMF.GetType() == dto.MetricType_HISTOGRAM || p.currentMF.GetType() == dto.MetricType_GAUGE_HISTOGRAM {
if p.currentLabelPair.GetName() == model.BucketLabel {
if p.currentBucket, p.err = parseFloat(p.currentLabelPair.GetValue()); p.err != nil {
// Create a more helpful error message.
@@ -371,12 +485,19 @@ func (p *TextParser) startLabelValue() stateFn {
return p.startLabelName
case '}':
if p.currentMF == nil {
p.parseError("invalid metric name")
return nil
}
p.currentMetric.Label = append(p.currentMetric.Label, p.currentLabelPairs...)
p.currentLabelPairs = nil
if p.skipBlankTab(); p.err != nil {
return nil // Unexpected end of input.
}
return p.readingValue
default:
p.parseError(fmt.Sprintf("unexpected end of label value %q", p.currentLabelPair.GetValue()))
p.currentLabelPairs = nil
return nil
}
}
@@ -387,7 +508,8 @@ func (p *TextParser) readingValue() stateFn {
// When we are here, we have read all the labels, so for the
// special case of a summary/histogram, we can finally find out
// if the metric already exists.
if p.currentMF.GetType() == dto.MetricType_SUMMARY {
switch p.currentMF.GetType() {
case dto.MetricType_SUMMARY:
signature := model.LabelsToSignature(p.currentLabels)
if summary := p.summaries[signature]; summary != nil {
p.currentMetric = summary
@@ -395,7 +517,7 @@ func (p *TextParser) readingValue() stateFn {
p.summaries[signature] = p.currentMetric
p.currentMF.Metric = append(p.currentMF.Metric, p.currentMetric)
}
} else if p.currentMF.GetType() == dto.MetricType_HISTOGRAM {
case dto.MetricType_HISTOGRAM, dto.MetricType_GAUGE_HISTOGRAM:
signature := model.LabelsToSignature(p.currentLabels)
if histogram := p.histograms[signature]; histogram != nil {
p.currentMetric = histogram
@@ -403,7 +525,7 @@ func (p *TextParser) readingValue() stateFn {
p.histograms[signature] = p.currentMetric
p.currentMF.Metric = append(p.currentMF.Metric, p.currentMetric)
}
} else {
default:
p.currentMF.Metric = append(p.currentMF.Metric, p.currentMetric)
}
if p.readTokenUntilWhitespace(); p.err != nil {
@@ -441,24 +563,38 @@ func (p *TextParser) readingValue() stateFn {
},
)
}
case dto.MetricType_HISTOGRAM:
case dto.MetricType_HISTOGRAM, dto.MetricType_GAUGE_HISTOGRAM:
// *sigh*
if p.currentMetric.Histogram == nil {
p.currentMetric.Histogram = &dto.Histogram{}
}
switch {
case p.currentIsHistogramCount:
p.currentMetric.Histogram.SampleCount = proto.Uint64(uint64(value))
if uintValue := uint64(value); value == float64(uintValue) {
p.currentMetric.Histogram.SampleCount = proto.Uint64(uintValue)
} else {
if value < 0 {
p.parseError(fmt.Sprintf("negative count for histogram %q", p.currentMF.GetName()))
return nil
}
p.currentMetric.Histogram.SampleCountFloat = proto.Float64(value)
}
case p.currentIsHistogramSum:
p.currentMetric.Histogram.SampleSum = proto.Float64(value)
case !math.IsNaN(p.currentBucket):
p.currentMetric.Histogram.Bucket = append(
p.currentMetric.Histogram.Bucket,
&dto.Bucket{
b := &dto.Bucket{
UpperBound: proto.Float64(p.currentBucket),
CumulativeCount: proto.Uint64(uint64(value)),
},
)
}
if uintValue := uint64(value); value == float64(uintValue) {
b.CumulativeCount = proto.Uint64(uintValue)
} else {
if value < 0 {
p.parseError(fmt.Sprintf("negative bucket population for histogram %q", p.currentMF.GetName()))
return nil
}
b.CumulativeCountFloat = proto.Float64(value)
}
p.currentMetric.Histogram.Bucket = append(p.currentMetric.Histogram.Bucket, b)
}
default:
p.err = fmt.Errorf("unexpected type for metric name %q", p.currentMF.GetName())
@@ -521,11 +657,19 @@ func (p *TextParser) readingType() stateFn {
if p.readTokenUntilNewline(false); p.err != nil {
return nil // Unexpected end of input.
}
metricType, ok := dto.MetricType_value[strings.ToUpper(p.currentToken.String())]
typ := strings.ToUpper(p.currentToken.String()) // Tolerate any combination of upper and lower case.
metricType, ok := dto.MetricType_value[typ] // Tolerate "gauge_histogram" (not originally part of the text format).
if !ok {
// We also want to tolerate "gaugehistogram" to mark a gauge
// histogram, because that string is used in OpenMetrics. Note,
// however, that gauge histograms do not officially exist in the
// classic text format.
if typ != "GAUGEHISTOGRAM" {
p.parseError(fmt.Sprintf("unknown metric type %q", p.currentToken.String()))
return nil
}
metricType = int32(dto.MetricType_GAUGE_HISTOGRAM)
}
p.currentMF.Type = dto.MetricType(metricType).Enum()
return p.startOfLine
}
@@ -585,6 +729,8 @@ func (p *TextParser) readTokenUntilNewline(recognizeEscapeSequence bool) {
p.currentToken.WriteByte(p.currentByte)
case 'n':
p.currentToken.WriteByte('\n')
case '"':
p.currentToken.WriteByte('"')
default:
p.parseError(fmt.Sprintf("invalid escape sequence '\\%c'", p.currentByte))
return
@@ -610,13 +756,45 @@ func (p *TextParser) readTokenUntilNewline(recognizeEscapeSequence bool) {
// but not into p.currentToken.
func (p *TextParser) readTokenAsMetricName() {
p.currentToken.Reset()
// A UTF-8 metric name must be quoted and may have escaped characters.
quoted := false
escaped := false
if !isValidMetricNameStart(p.currentByte) {
return
}
for {
for p.err == nil {
if escaped {
switch p.currentByte {
case '\\':
p.currentToken.WriteByte(p.currentByte)
case 'n':
p.currentToken.WriteByte('\n')
case '"':
p.currentToken.WriteByte('"')
default:
p.parseError(fmt.Sprintf("invalid escape sequence '\\%c'", p.currentByte))
return
}
escaped = false
} else {
switch p.currentByte {
case '"':
quoted = !quoted
if !quoted {
p.currentByte, p.err = p.buf.ReadByte()
if p.err != nil || !isValidMetricNameContinuation(p.currentByte) {
return
}
case '\n':
p.parseError(fmt.Sprintf("metric name %q contains unescaped new-line", p.currentToken.String()))
return
case '\\':
escaped = true
default:
p.currentToken.WriteByte(p.currentByte)
}
}
p.currentByte, p.err = p.buf.ReadByte()
if !isValidMetricNameContinuation(p.currentByte, quoted) || (!quoted && p.currentByte == ' ') {
return
}
}
@@ -628,13 +806,45 @@ func (p *TextParser) readTokenAsMetricName() {
// but not into p.currentToken.
func (p *TextParser) readTokenAsLabelName() {
p.currentToken.Reset()
// A UTF-8 label name must be quoted and may have escaped characters.
quoted := false
escaped := false
if !isValidLabelNameStart(p.currentByte) {
return
}
for {
for p.err == nil {
if escaped {
switch p.currentByte {
case '\\':
p.currentToken.WriteByte(p.currentByte)
case 'n':
p.currentToken.WriteByte('\n')
case '"':
p.currentToken.WriteByte('"')
default:
p.parseError(fmt.Sprintf("invalid escape sequence '\\%c'", p.currentByte))
return
}
escaped = false
} else {
switch p.currentByte {
case '"':
quoted = !quoted
if !quoted {
p.currentByte, p.err = p.buf.ReadByte()
if p.err != nil || !isValidLabelNameContinuation(p.currentByte) {
return
}
case '\n':
p.parseError(fmt.Sprintf("label name %q contains unescaped new-line", p.currentToken.String()))
return
case '\\':
escaped = true
default:
p.currentToken.WriteByte(p.currentByte)
}
}
p.currentByte, p.err = p.buf.ReadByte()
if !isValidLabelNameContinuation(p.currentByte, quoted) || (!quoted && p.currentByte == '=') {
return
}
}
@@ -660,6 +870,7 @@ func (p *TextParser) readTokenAsLabelValue() {
p.currentToken.WriteByte('\n')
default:
p.parseError(fmt.Sprintf("invalid escape sequence '\\%c'", p.currentByte))
p.currentLabelPairs = nil
return
}
escaped = false
@@ -685,6 +896,10 @@ func (p *TextParser) setOrCreateCurrentMF() {
p.currentIsHistogramCount = false
p.currentIsHistogramSum = false
name := p.currentToken.String()
if !p.scheme.IsValidMetricName(name) {
p.parseError(fmt.Sprintf("invalid metric name %q", name))
return
}
if p.currentMF = p.metricFamiliesByName[name]; p.currentMF != nil {
return
}
@@ -703,7 +918,8 @@ func (p *TextParser) setOrCreateCurrentMF() {
}
histogramName := histogramMetricName(name)
if p.currentMF = p.metricFamiliesByName[histogramName]; p.currentMF != nil {
if p.currentMF.GetType() == dto.MetricType_HISTOGRAM {
if p.currentMF.GetType() == dto.MetricType_HISTOGRAM ||
p.currentMF.GetType() == dto.MetricType_GAUGE_HISTOGRAM {
if isCount(name) {
p.currentIsHistogramCount = true
}
@@ -718,19 +934,19 @@ func (p *TextParser) setOrCreateCurrentMF() {
}
func isValidLabelNameStart(b byte) bool {
return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || b == '_'
return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || b == '_' || b == '"'
}
func isValidLabelNameContinuation(b byte) bool {
return isValidLabelNameStart(b) || (b >= '0' && b <= '9')
func isValidLabelNameContinuation(b byte, quoted bool) bool {
return isValidLabelNameStart(b) || (b >= '0' && b <= '9') || (quoted && utf8.ValidString(string(b)))
}
func isValidMetricNameStart(b byte) bool {
return isValidLabelNameStart(b) || b == ':'
}
func isValidMetricNameContinuation(b byte) bool {
return isValidLabelNameContinuation(b) || b == ':'
func isValidMetricNameContinuation(b byte, quoted bool) bool {
return isValidLabelNameContinuation(b, quoted) || b == ':'
}
func isBlankOrTab(b byte) bool {
@@ -775,7 +991,7 @@ func histogramMetricName(name string) string {
func parseFloat(s string) (float64, error) {
if strings.ContainsAny(s, "pP_") {
return 0, fmt.Errorf("unsupported character in float")
return 0, errors.New("unsupported character in float")
}
return strconv.ParseFloat(s, 64)
}

View File

@@ -1,67 +0,0 @@
PACKAGE
package goautoneg
import "bitbucket.org/ww/goautoneg"
HTTP Content-Type Autonegotiation.
The functions in this package implement the behaviour specified in
http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html
Copyright (c) 2011, Open Knowledge Foundation Ltd.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
Neither the name of the Open Knowledge Foundation Ltd. nor the
names of its contributors may be used to endorse or promote
products derived from this software without specific prior written
permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
FUNCTIONS
func Negotiate(header string, alternatives []string) (content_type string)
Negotiate the most appropriate content_type given the accept header
and a list of alternatives.
func ParseAccept(header string) (accept []Accept)
Parse an Accept Header string returning a sorted list
of clauses
TYPES
type Accept struct {
Type, SubType string
Q float32
Params map[string]string
}
Structure to represent a clause in an HTTP Accept Header
SUBDIRECTORIES
.hg

View File

@@ -1,160 +0,0 @@
/*
Copyright (c) 2011, Open Knowledge Foundation Ltd.
All rights reserved.
HTTP Content-Type Autonegotiation.
The functions in this package implement the behaviour specified in
http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
Neither the name of the Open Knowledge Foundation Ltd. nor the
names of its contributors may be used to endorse or promote
products derived from this software without specific prior written
permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package goautoneg
import (
"sort"
"strconv"
"strings"
)
// Structure to represent a clause in an HTTP Accept Header
type Accept struct {
Type, SubType string
Q float64
Params map[string]string
}
// For internal use, so that we can use the sort interface
type accept_slice []Accept
func (accept accept_slice) Len() int {
slice := []Accept(accept)
return len(slice)
}
func (accept accept_slice) Less(i, j int) bool {
slice := []Accept(accept)
ai, aj := slice[i], slice[j]
if ai.Q > aj.Q {
return true
}
if ai.Type != "*" && aj.Type == "*" {
return true
}
if ai.SubType != "*" && aj.SubType == "*" {
return true
}
return false
}
func (accept accept_slice) Swap(i, j int) {
slice := []Accept(accept)
slice[i], slice[j] = slice[j], slice[i]
}
// Parse an Accept Header string returning a sorted list
// of clauses
func ParseAccept(header string) (accept []Accept) {
parts := strings.Split(header, ",")
accept = make([]Accept, 0, len(parts))
for _, part := range parts {
part := strings.Trim(part, " ")
a := Accept{}
a.Params = make(map[string]string)
a.Q = 1.0
mrp := strings.Split(part, ";")
media_range := mrp[0]
sp := strings.Split(media_range, "/")
a.Type = strings.Trim(sp[0], " ")
switch {
case len(sp) == 1 && a.Type == "*":
a.SubType = "*"
case len(sp) == 2:
a.SubType = strings.Trim(sp[1], " ")
default:
continue
}
if len(mrp) == 1 {
accept = append(accept, a)
continue
}
for _, param := range mrp[1:] {
sp := strings.SplitN(param, "=", 2)
if len(sp) != 2 {
continue
}
token := strings.Trim(sp[0], " ")
if token == "q" {
a.Q, _ = strconv.ParseFloat(sp[1], 32)
} else {
a.Params[token] = strings.Trim(sp[1], " ")
}
}
accept = append(accept, a)
}
slice := accept_slice(accept)
sort.Sort(slice)
return
}
// Negotiate the most appropriate content_type given the accept header
// and a list of alternatives.
func Negotiate(header string, alternatives []string) (content_type string) {
asp := make([][]string, 0, len(alternatives))
for _, ctype := range alternatives {
asp = append(asp, strings.SplitN(ctype, "/", 2))
}
for _, clause := range ParseAccept(header) {
for i, ctsp := range asp {
if clause.Type == ctsp[0] && clause.SubType == ctsp[1] {
content_type = alternatives[i]
return
}
if clause.Type == ctsp[0] && clause.SubType == "*" {
content_type = alternatives[i]
return
}
if clause.Type == "*" && clause.SubType == "*" {
content_type = alternatives[i]
return
}
}
}
return
}

View File

@@ -14,6 +14,7 @@
package model
import (
"errors"
"fmt"
"time"
)
@@ -64,7 +65,7 @@ func (a *Alert) Resolved() bool {
return a.ResolvedAt(time.Now())
}
// ResolvedAt returns true off the activity interval ended before
// ResolvedAt returns true iff the activity interval ended before
// the given timestamp.
func (a *Alert) ResolvedAt(ts time.Time) bool {
if a.EndsAt.IsZero() {
@@ -75,7 +76,12 @@ func (a *Alert) ResolvedAt(ts time.Time) bool {
// Status returns the status of the alert.
func (a *Alert) Status() AlertStatus {
if a.Resolved() {
return a.StatusAt(time.Now())
}
// StatusAt returns the status of the alert at the given timestamp.
func (a *Alert) StatusAt(ts time.Time) AlertStatus {
if a.ResolvedAt(ts) {
return AlertResolved
}
return AlertFiring
@@ -84,16 +90,16 @@ func (a *Alert) Status() AlertStatus {
// Validate checks whether the alert data is inconsistent.
func (a *Alert) Validate() error {
if a.StartsAt.IsZero() {
return fmt.Errorf("start time missing")
return errors.New("start time missing")
}
if !a.EndsAt.IsZero() && a.EndsAt.Before(a.StartsAt) {
return fmt.Errorf("start time must be before end time")
return errors.New("start time must be before end time")
}
if err := a.Labels.Validate(); err != nil {
return fmt.Errorf("invalid label set: %w", err)
}
if len(a.Labels) == 0 {
return fmt.Errorf("at least one label pair required")
return errors.New("at least one label pair required")
}
if err := a.Annotations.Validate(); err != nil {
return fmt.Errorf("invalid annotations: %w", err)
@@ -127,6 +133,17 @@ func (as Alerts) HasFiring() bool {
return false
}
// HasFiringAt returns true iff one of the alerts is not resolved
// at the time ts.
func (as Alerts) HasFiringAt(ts time.Time) bool {
for _, a := range as {
if !a.ResolvedAt(ts) {
return true
}
}
return false
}
// Status returns StatusFiring iff at least one of the alerts is firing.
func (as Alerts) Status() AlertStatus {
if as.HasFiring() {
@@ -134,3 +151,12 @@ func (as Alerts) Status() AlertStatus {
}
return AlertResolved
}
// StatusAt returns StatusFiring iff at least one of the alerts is firing
// at the time ts.
func (as Alerts) StatusAt(ts time.Time) AlertStatus {
if as.HasFiringAt(ts) {
return AlertFiring
}
return AlertResolved
}

View File

@@ -22,7 +22,7 @@ import (
)
const (
// AlertNameLabel is the name of the label containing the an alert's name.
// AlertNameLabel is the name of the label containing the alert's name.
AlertNameLabel = "alertname"
// ExportedLabelPrefix is the prefix to prepend to the label names present in
@@ -32,6 +32,12 @@ const (
// MetricNameLabel is the label name indicating the metric name of a
// timeseries.
MetricNameLabel = "__name__"
// MetricTypeLabel is the label name indicating the metric type of
// timeseries as per the PROM-39 proposal.
MetricTypeLabel = "__type__"
// MetricUnitLabel is the label name indicating the metric unit of
// timeseries as per the PROM-39 proposal.
MetricUnitLabel = "__unit__"
// SchemeLabel is the name of the label that holds the scheme on which to
// scrape a target.
@@ -97,27 +103,24 @@ var LabelNameRE = regexp.MustCompile("^[a-zA-Z_][a-zA-Z0-9_]*$")
// therewith.
type LabelName string
// IsValid returns true iff name matches the pattern of LabelNameRE for legacy
// names, and iff it's valid UTF-8 if NameValidationScheme is set to
// UTF8Validation. For the legacy matching, it does not use LabelNameRE for the
// check but a much faster hardcoded implementation.
// IsValid returns true iff the name matches the pattern of LabelNameRE when
// NameValidationScheme is set to LegacyValidation, or valid UTF-8 if
// NameValidationScheme is set to UTF8Validation.
//
// Deprecated: This method should not be used and may be removed in the future.
// Use [ValidationScheme.IsValidLabelName] instead.
func (ln LabelName) IsValid() bool {
if len(ln) == 0 {
return false
return NameValidationScheme.IsValidLabelName(string(ln))
}
switch NameValidationScheme {
case LegacyValidation:
for i, b := range ln {
if !((b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || b == '_' || (b >= '0' && b <= '9' && i > 0)) {
return false
}
}
case UTF8Validation:
return utf8.ValidString(string(ln))
default:
panic(fmt.Sprintf("Invalid name validation scheme requested: %d", NameValidationScheme))
}
return true
// IsValidLegacy returns true iff name matches the pattern of LabelNameRE for
// legacy names. It does not use LabelNameRE for the check but a much faster
// hardcoded implementation.
//
// Deprecated: This method should not be used and may be removed in the future.
// Use [LegacyValidation.IsValidLabelName] instead.
func (ln LabelName) IsValidLegacy() bool {
return LegacyValidation.IsValidLabelName(string(ln))
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.

Some files were not shown because too many files have changed in this diff Show More