151 lines
2.9 KiB
Go
151 lines
2.9 KiB
Go
package mlflow
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"reflect"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// ViewType is a named string type for the ViewType* values declared below.
type ViewType string

// ViewType* values. They are deliberately untyped string constants, so they
// can be used both as plain strings and as ViewType values.
const (
	ViewTypeAll     = "ALL"
	ViewTypeActive  = "ACTIVE_ONLY"
	ViewTypeDeleted = "DELETED_ONLY"
)
|
|
|
|
// MlflowApiError is an error decoded from the JSON body of a failed request.
type MlflowApiError struct {
	ErrorCode string `json:"error_code,omitempty"` // "error_code" field of the reply
	ErrorDesc string `json:"message,omitempty"`    // "message" field of the reply
}

// Error renders the error as "Request Error <code>: <message>".
func (e MlflowApiError) Error() string {
	return fmt.Sprintf("Request Error %s: %s", e.ErrorCode, e.ErrorDesc)
}

// Status-specific variants of MlflowApiError. They share its fields, so a
// MlflowApiError value can be converted to either type directly.
type (
	// MlflowApiErrorNotFound marks an error reply that carried HTTP status 404.
	MlflowApiErrorNotFound MlflowApiError
	// MlflowApiError400 marks an error reply that carried HTTP status 400.
	MlflowApiError400 MlflowApiError
)

// Error renders the error as "Request Error 404 <code>: <message>".
func (e MlflowApiErrorNotFound) Error() string {
	return fmt.Sprintf("Request Error 404 %s: %s", e.ErrorCode, e.ErrorDesc)
}

// Error renders the error as "Request Error 400 <code>: <message>".
func (e MlflowApiError400) Error() string {
	return fmt.Sprintf("Request Error 400 %s: %s", e.ErrorCode, e.ErrorDesc)
}
|
|
|
|
// Timestamp is a point in time stored as whole seconds since the Unix epoch.
type Timestamp int64

// Time converts t to a time.Time (in the local time zone, as time.Unix does).
func (t *Timestamp) Time() time.Time {
	return time.Unix(int64(*t), 0)
}

// Set stores tm in t, truncated to whole seconds, and also returns the stored
// value. Returning the value keeps existing `x := ts.Set(now)` callers working;
// a nil receiver is tolerated and only the value is returned.
func (t *Timestamp) Set(tm time.Time) Timestamp {
	ts := Timestamp(tm.Unix())
	if t != nil {
		*t = ts
	}
	return ts
}
|
|
|
|
// TimestampMs is a point in time stored as milliseconds since the Unix epoch.
type TimestampMs int64

// Time converts t to a time.Time (in the local time zone, as time.UnixMilli
// does), keeping millisecond precision. Negative values are handled correctly,
// unlike plain integer division by 1000, which truncates toward zero.
func (t *TimestampMs) Time() time.Time {
	return time.UnixMilli(int64(*t))
}

// Set stores tm in t as milliseconds since the Unix epoch and also returns
// the stored value. Sub-millisecond precision is dropped. A nil receiver is
// tolerated and only the value is returned.
func (t *TimestampMs) Set(tm time.Time) TimestampMs {
	ms := TimestampMs(tm.UnixMilli())
	if t != nil {
		*t = ms
	}
	return ms
}
|
|
|
|
func apiReadReply(resp *http.Response, err error, ref any) error {
|
|
if err != nil {
|
|
fmt.Println(err.Error())
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
fmt.Println(err.Error())
|
|
return err
|
|
}
|
|
|
|
if resp.StatusCode != 200 && resp.StatusCode != 201 {
|
|
// Error code return
|
|
err := MlflowApiError{}
|
|
err2 := json.Unmarshal(body, &err)
|
|
if err2 != nil {
|
|
fmt.Println("Error when parsing reply: ", err2.Error(), "\n", string(body))
|
|
return err2
|
|
}
|
|
|
|
switch resp.StatusCode {
|
|
case 400:
|
|
return MlflowApiError400(err)
|
|
case 404:
|
|
return MlflowApiErrorNotFound(err)
|
|
default:
|
|
return err
|
|
}
|
|
}
|
|
|
|
err = json.Unmarshal(body, &ref)
|
|
|
|
//fmt.Println(string(body))
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type typeTag struct {
|
|
FieldName string
|
|
Omitempty bool
|
|
}
|
|
|
|
func parseJsonTag(tag reflect.StructTag, FieldName string) typeTag {
|
|
rv := typeTag{}
|
|
tag1 := strings.Split(string(tag), ":")
|
|
|
|
if tag1[0] == "json" {
|
|
tag2 := strings.Split(strings.ReplaceAll(tag1[1], `"`, ""), ",")
|
|
rv.FieldName = tag2[0]
|
|
rv.Omitempty = (len(tag2) == 2 && tag2[1] == "omitempty")
|
|
}
|
|
if rv.FieldName == "" {
|
|
rv.FieldName = FieldName
|
|
}
|
|
return rv
|
|
}
|
|
|
|
func UrlEncode(s any) string {
|
|
rv := ""
|
|
rvCount := 0
|
|
|
|
v := reflect.ValueOf(s)
|
|
typeOfV := v.Type()
|
|
|
|
values := make([]interface{}, v.NumField())
|
|
|
|
for i := 0; i < v.NumField(); i++ {
|
|
values[i] = v.Field(i).Interface()
|
|
name := typeOfV.Field(i).Name
|
|
tag := parseJsonTag(typeOfV.Field(i).Tag, name)
|
|
val := v.Field(i)
|
|
|
|
if tag.FieldName != "-" && (!tag.Omitempty || fmt.Sprint(val) != "") {
|
|
if rvCount > 0 {
|
|
rv += "&"
|
|
} else {
|
|
rv += "?"
|
|
}
|
|
rv += fmt.Sprintf("%s=%s", tag.FieldName, url.QueryEscape(fmt.Sprint(val)))
|
|
rvCount++
|
|
}
|
|
|
|
}
|
|
return rv
|
|
}
|