Compare commits

..

No commits in common. "main" and "v1.0.3" have entirely different histories.
main ... v1.0.3

4 changed files with 13 additions and 152 deletions

23
api.go
View File

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url"
"reflect" "reflect"
"strings" "strings"
"time" "time"
@ -26,18 +25,6 @@ func (e MlflowApiError) Error() string {
return fmt.Sprint("Request Error " + string(e.ErrorCode) + ": " + e.ErrorDesc) return fmt.Sprint("Request Error " + string(e.ErrorCode) + ": " + e.ErrorDesc)
} }
type MlflowApiErrorNotFound MlflowApiError
func (e MlflowApiErrorNotFound) Error() string {
return fmt.Sprint("Request Error 404 " + string(e.ErrorCode) + ": " + e.ErrorDesc)
}
type MlflowApiError400 MlflowApiError
func (e MlflowApiError400) Error() string {
return fmt.Sprint("Request Error 400 " + string(e.ErrorCode) + ": " + e.ErrorDesc)
}
type Timestamp int64 type Timestamp int64
func (t *Timestamp) Time() time.Time { func (t *Timestamp) Time() time.Time {
@ -79,16 +66,8 @@ func apiReadReply(resp *http.Response, err error, ref any) error {
fmt.Println("Error when parsing reply: ", err2.Error(), "\n", string(body)) fmt.Println("Error when parsing reply: ", err2.Error(), "\n", string(body))
return err2 return err2
} }
switch resp.StatusCode {
case 400:
return MlflowApiError400(err)
case 404:
return MlflowApiErrorNotFound(err)
default:
return err return err
} }
}
err = json.Unmarshal(body, &ref) err = json.Unmarshal(body, &ref)
@ -141,7 +120,7 @@ func UrlEncode(s any) string {
} else { } else {
rv += "?" rv += "?"
} }
rv += fmt.Sprintf("%s=%s", tag.FieldName, url.QueryEscape(fmt.Sprint(val))) rv += fmt.Sprintf("%s=%s", tag.FieldName, fmt.Sprint(val))
rvCount++ rvCount++
} }

View File

@ -1,6 +1,6 @@
package mlflow package mlflow
type Config struct { type Config struct {
ApiURI string `json:"apiUrl,omitempty" yaml:"apiUri,omitempty"` ApiURI string
IgnoreSSL bool `json:"ignoreSSL,omitempty" yaml:"ignoreSSL,omitempty"` IgnoreSSL bool
} }

View File

@ -2,7 +2,6 @@ package mlflow
import ( import (
"bytes" "bytes"
"fmt"
"net/http" "net/http"
) )
@ -286,35 +285,3 @@ func GetModelVersionByAlias(conf Config, req GetModelVersionByAliasRequest) (*Mo
return &r.ModelVersion, nil return &r.ModelVersion, nil
} }
func SearchModelVersionsAdvanced(conf Config, req SearchModelVersionsRequestAdvanced) ([]ModelVersion, error) {
var models []ModelVersion = []ModelVersion{}
var j int = 0
for {
resp, err := SearchModelVersions(conf, req.SearchModelVersionsRequest)
if err != nil {
return nil, err
}
for _, item := range resp.ModelVersions {
if req.Stage == "" || req.Stage == item.CurrentStage {
models = append(models, item)
continue
}
}
if j > int(req.Limit) {
fmt.Println("Limit exceed, terminating")
break
}
if resp.NextPageToken == "" {
break
} else {
req.PageToken = resp.NextPageToken
}
j++
}
return models, nil
}

View File

@ -1,11 +1,6 @@
package mlflow package mlflow
import ( import "encoding/json"
"encoding/json"
"fmt"
"reflect"
"sort"
)
type ModelVersionStatus string type ModelVersionStatus string
@ -51,8 +46,8 @@ type CreateRegisteredModelResponse struct {
type ModelVersion struct { type ModelVersion struct {
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
Version string `json:"version,omitempty"` Version string `json:"version,omitempty"`
CreationTimestamp TimestampMs `json:"creation_timestamp,omitempty"` CreationTimestamp Timestamp `json:"creation_timestamp,omitempty"`
LastUpdateTimestamp TimestampMs `json:"last_update_timestamp,omitempty"` LastUpdateTimestamp Timestamp `json:"last_update_timestamp,omitempty"`
UserId string `json:"user_id,omitempty"` UserId string `json:"user_id,omitempty"`
CurrentStage string `json:"current_stage,omitempty"` CurrentStage string `json:"current_stage,omitempty"`
Description string `json:"description,omitempty"` Description string `json:"description,omitempty"`
@ -65,78 +60,6 @@ type ModelVersion struct {
Aliases []string `json:"aliases,omitempty"` Aliases []string `json:"aliases,omitempty"`
} }
func (ref *ModelVersion) GetTag(key string) *ModelVersionTag {
for _, tag := range ref.Tags {
if tag.Key == key {
return &tag
}
}
return nil
}
func (ref *ModelVersion) IsEqual(target *ModelVersion) bool {
vRef := reflect.ValueOf(*ref)
vTarget := reflect.ValueOf(*target)
typeOfV := vRef.Type()
values := make([]interface{}, vRef.NumField())
for i := 0; i < vRef.NumField(); i++ {
values[i] = vRef.Field(i).Interface()
refName := typeOfV.Field(i).Name
refType := typeOfV.Field(i).Type.String()
refVal := vRef.Field(i)
targetVal := reflect.Indirect(vTarget).FieldByName(refName)
switch refType {
case "string", "mlflow.ModelVersionStatus":
if refVal.String() != targetVal.String() {
fmt.Println("NE STRING: ", refType, refName, refVal, targetVal, (refVal.String() == targetVal.String()))
return false
}
case "mlflow.TimestampMs":
if refVal.Int() != targetVal.Int() {
fmt.Println("NE TimestampMs: ", refType, refName, refVal, targetVal, (refVal.Int() == targetVal.Int()))
return false
}
case "[]mlflow.ModelVersionTag":
sliceRef, _ := refVal.Interface().([]ModelVersionTag)
sliceTarget, _ := targetVal.Interface().([]ModelVersionTag)
if len(sliceRef) != len(sliceTarget) {
//fmt.Println("Not equal - len mismatch")
return false
}
for _, val := range sliceRef {
if tag2 := target.GetTag(val.Key); tag2 == nil || tag2.Value != val.Value {
return false
}
}
case "[]string":
sliceRef, _ := refVal.Interface().([]string)
sliceTarget, _ := targetVal.Interface().([]string)
sort.Strings(sliceRef)
sort.Strings(sliceTarget)
if len(sliceRef) != len(sliceTarget) {
//fmt.Println("Not equal - len mismatch")
return false
}
for xI, xRef := range sliceRef {
if xRef != sliceTarget[xI] {
return false
}
}
default:
fmt.Println("Unknown: ", refType, refName, refVal, targetVal)
}
}
return true
}
type RegisteredModel struct { type RegisteredModel struct {
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
CreationTimestamp Timestamp `json:"creation_timestamp,omitempty"` CreationTimestamp Timestamp `json:"creation_timestamp,omitempty"`
@ -215,10 +138,9 @@ type GetDownloadUriForModelVersionArtifactsResponse struct {
} }
type SearchModelVersionsRequest struct { type SearchModelVersionsRequest struct {
// Valid keys are '{'version_number', 'source_path', 'run_id', 'name'}'
Filter string `json:"filter,omitempty"` Filter string `json:"filter,omitempty"`
PageSize int64 `json:"max_results,omitempty"` MaxResults int64 `json:"max_results,omitempty"`
OrderBy string `json:"order_by,omitempty"` OrderBy []string `json:"order_by,omitempty"`
PageToken string `json:"page_token,omitempty"` PageToken string `json:"page_token,omitempty"`
} }
@ -228,8 +150,8 @@ func (e *SearchModelVersionsRequest) serialize() []byte {
} }
func (e *SearchModelVersionsRequest) __init() *SearchModelVersionsRequest { func (e *SearchModelVersionsRequest) __init() *SearchModelVersionsRequest {
if e.PageSize == 0 { if e.MaxResults == 0 {
e.PageSize = 1000 e.MaxResults = 1000
} }
return e return e
@ -280,7 +202,6 @@ func (e *TransitionModelVersionStageRequest) serialize() []byte {
type TransitionModelVersionStageResponse GetModelVersionResponse type TransitionModelVersionStageResponse GetModelVersionResponse
type SearchRegisteredModelsRequest struct { type SearchRegisteredModelsRequest struct {
// Valid keys are '{'name'}'
Filter string `json:"filter,omitempty"` Filter string `json:"filter,omitempty"`
MaxResults int64 `json:"max_results,omitempty"` MaxResults int64 `json:"max_results,omitempty"`
OrderBy string `json:"order_by,omitempty"` OrderBy string `json:"order_by,omitempty"`
@ -379,9 +300,3 @@ func (e *DeleteRegisteredModelAliasRequest) serialize() []byte {
} }
type GetModelVersionByAliasResponse GetModelVersionResponse type GetModelVersionByAliasResponse GetModelVersionResponse
type SearchModelVersionsRequestAdvanced struct {
SearchModelVersionsRequest
Stage string
Limit int64
}