mlflow-client-go/api.go

131 lines
2.5 KiB
Go

package mlflow
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"reflect"
"strings"
"time"
)
// View-type filter values: all entities, only active ones, or only deleted
// ones. They are deliberately untyped string constants, so they can be used
// both as plain strings and as ViewType values.
const (
	ViewTypeAll     = "ALL"
	ViewTypeActive  = "ACTIVE_ONLY"
	ViewTypeDeleted = "DELETED_ONLY"
)

// ViewType is a string holding one of the ViewType* filter values.
type ViewType string
// MlflowApiError is the JSON error body decoded from a failed request
// ({"error_code": ..., "message": ...}). It satisfies the error interface.
type MlflowApiError struct {
	ErrorCode string `json:"error_code,omitempty"` // error code reported by the server
	ErrorDesc string `json:"message,omitempty"`    // human-readable error message
}

// Error formats the error as "Request Error <code>: <message>".
func (e MlflowApiError) Error() string {
	return "Request Error " + e.ErrorCode + ": " + e.ErrorDesc
}
// Timestamp is a Unix time expressed in whole seconds.
type Timestamp int64

// Time converts t to a time.Time via time.Unix (local time zone).
func (t *Timestamp) Time() time.Time {
	return time.Unix(int64(*t), 0)
}

// Set stores tm, truncated to whole seconds, into t and returns the stored
// value. Previously Set only returned the value and left t unchanged despite
// its name; the return value is kept for backward compatibility. A nil t is
// tolerated: the converted value is still returned.
func (t *Timestamp) Set(tm time.Time) Timestamp {
	v := Timestamp(tm.Unix())
	if t != nil {
		*t = v
	}
	return v
}
// TimestampMs is a Unix time expressed in milliseconds.
type TimestampMs int64

// Time converts t to a time.Time (local time zone), keeping millisecond
// precision. The previous version divided by 1000 and dropped the
// sub-second part, which was also wrong for negative values because integer
// division truncates toward zero.
func (t *TimestampMs) Time() time.Time {
	return time.UnixMilli(int64(*t))
}

// Set stores tm, truncated to milliseconds, into t and returns the stored
// value. A nil t is tolerated: the converted value is still returned.
func (t *TimestampMs) Set(tm time.Time) TimestampMs {
	v := TimestampMs(tm.UnixMilli())
	if t != nil {
		*t = v
	}
	return v
}
// apiReadReply turns an HTTP reply into a decoded result or an error.
//
// resp and err are expected to be the pair returned by a net/http client
// call (NOTE(review): confirm against callers). Behavior:
//   - A non-nil err is returned unchanged.
//   - For a 2xx status, the body is JSON-decoded into ref, which must be a
//     pointer. A nil ref or an empty body (e.g. 204 No Content) skips decoding.
//   - For any other status, the body is decoded as an MlflowApiError and
//     returned. If the body is not that JSON envelope (e.g. an HTML page from
//     a proxy) or carries neither field, the error that is returned instead
//     includes the HTTP status and the start of the body.
//
// Errors are returned, never printed, so the caller decides how to report
// them.
func apiReadReply(resp *http.Response, err error, ref any) error {
	if err != nil {
		return err
	}
	if resp == nil {
		return fmt.Errorf("mlflow: nil HTTP response")
	}
	// The Close error is ignored: the body has been fully read at this point.
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("mlflow: reading HTTP %d reply: %w", resp.StatusCode, err)
	}

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		// Keep error messages bounded when the server returns a large page.
		const maxSnippet = 256
		snippet := body
		if len(snippet) > maxSnippet {
			snippet = snippet[:maxSnippet]
		}
		var apiErr MlflowApiError
		if jerr := json.Unmarshal(body, &apiErr); jerr != nil {
			return fmt.Errorf("mlflow: HTTP %d with unparsable error reply %q: %w", resp.StatusCode, snippet, jerr)
		}
		if apiErr.ErrorCode == "" && apiErr.ErrorDesc == "" {
			return fmt.Errorf("mlflow: HTTP %d with unexpected error reply %q", resp.StatusCode, snippet)
		}
		return apiErr
	}

	if ref == nil || len(body) == 0 {
		return nil
	}
	// Decode into ref itself, not &ref. Decoding into &ref would silently
	// replace the local interface value when ref is not a pointer, instead
	// of reporting the misuse.
	return json.Unmarshal(body, ref)
}
// typeTag holds the parts of a `json:"..."` struct tag that UrlEncode uses.
type typeTag struct {
	FieldName string // query parameter name; "-" means the field is skipped
	Omitempty bool   // true when the tag lists the "omitempty" option
}

// parseJsonTag extracts the json name and the omitempty option from tag.
// When the tag has no json key, or the json name is empty, fieldName is used.
// Other tag keys (e.g. xml:"...") are ignored, and options may appear in any
// order (e.g. `json:"n,string,omitempty"`).
func parseJsonTag(tag reflect.StructTag, fieldName string) typeTag {
	rv := typeTag{FieldName: fieldName}
	value, ok := tag.Lookup("json")
	if !ok {
		return rv
	}
	parts := strings.Split(value, ",") // always has at least one element
	if parts[0] != "" {
		rv.FieldName = parts[0]
	}
	for _, opt := range parts[1:] {
		if opt == "omitempty" {
			rv.Omitempty = true
		}
	}
	return rv
}

// UrlEncode renders the exported fields of the struct s (or *struct) as a
// URL query string such as "?a=1&b=x+y". Fields appear in declaration order
// and are named after their json tags.
//
// Rules:
//   - Fields tagged json:"-" and unexported fields are skipped.
//   - Fields tagged omitempty are skipped when empty by encoding/json rules.
//   - Each name and value is query-escaped individually, so '?', '&' and
//     '=' keep their meaning as separators.
//
// It returns "" when no field is emitted, or when s is nil or not a struct.
func UrlEncode(s any) string {
	v := reflect.Indirect(reflect.ValueOf(s))
	if v.Kind() != reflect.Struct {
		return ""
	}
	typ := v.Type()
	var b strings.Builder
	for i := 0; i < v.NumField(); i++ {
		field := typ.Field(i)
		if !field.IsExported() {
			continue
		}
		tag := parseJsonTag(field.Tag, field.Name)
		val := v.Field(i)
		if tag.FieldName == "-" || (tag.Omitempty && queryValueIsEmpty(val)) {
			continue
		}
		if b.Len() == 0 {
			b.WriteByte('?')
		} else {
			b.WriteByte('&')
		}
		b.WriteString(url.QueryEscape(tag.FieldName))
		b.WriteByte('=')
		b.WriteString(url.QueryEscape(queryValueString(val)))
	}
	return b.String()
}

// queryValueIsEmpty reports whether v is "empty" for omitempty, using the
// same rules as encoding/json: zero length for arrays/maps/slices/strings,
// and the zero value for bools, numbers, interfaces and pointers.
func queryValueIsEmpty(v reflect.Value) bool {
	switch v.Kind() {
	case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
		return v.Len() == 0
	case reflect.Bool,
		reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
		reflect.Float32, reflect.Float64,
		reflect.Interface, reflect.Pointer:
		return v.IsZero()
	}
	return false
}

// queryValueString formats v with fmt's default formatting. It dereferences
// pointers and interfaces first, so a *int renders its value rather than an
// address; a nil pointer or interface renders as "".
func queryValueString(v reflect.Value) string {
	for v.Kind() == reflect.Pointer || v.Kind() == reflect.Interface {
		if v.IsNil() {
			return ""
		}
		v = v.Elem()
	}
	return fmt.Sprint(v)
}