Compare commits
8 Commits
Author | SHA1 | Date |
---|---|---|
|
404a1a9a61 | |
|
d99e197d4b | |
|
c7aa20f7f1 | |
|
8de04ec577 | |
|
f83eae6e58 | |
|
1340962d12 | |
|
2420a73796 | |
|
a18920d818 |
23
api.go
23
api.go
|
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -25,6 +26,18 @@ func (e MlflowApiError) Error() string {
|
|||
return fmt.Sprint("Request Error " + string(e.ErrorCode) + ": " + e.ErrorDesc)
|
||||
}
|
||||
|
||||
type MlflowApiErrorNotFound MlflowApiError
|
||||
|
||||
func (e MlflowApiErrorNotFound) Error() string {
|
||||
return fmt.Sprint("Request Error 404 " + string(e.ErrorCode) + ": " + e.ErrorDesc)
|
||||
}
|
||||
|
||||
type MlflowApiError400 MlflowApiError
|
||||
|
||||
func (e MlflowApiError400) Error() string {
|
||||
return fmt.Sprint("Request Error 400 " + string(e.ErrorCode) + ": " + e.ErrorDesc)
|
||||
}
|
||||
|
||||
type Timestamp int64
|
||||
|
||||
func (t *Timestamp) Time() time.Time {
|
||||
|
@ -66,8 +79,16 @@ func apiReadReply(resp *http.Response, err error, ref any) error {
|
|||
fmt.Println("Error when parsing reply: ", err2.Error(), "\n", string(body))
|
||||
return err2
|
||||
}
|
||||
|
||||
switch resp.StatusCode {
|
||||
case 400:
|
||||
return MlflowApiError400(err)
|
||||
case 404:
|
||||
return MlflowApiErrorNotFound(err)
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err = json.Unmarshal(body, &ref)
|
||||
|
||||
|
@ -120,7 +141,7 @@ func UrlEncode(s any) string {
|
|||
} else {
|
||||
rv += "?"
|
||||
}
|
||||
rv += fmt.Sprintf("%s=%s", tag.FieldName, fmt.Sprint(val))
|
||||
rv += fmt.Sprintf("%s=%s", tag.FieldName, url.QueryEscape(fmt.Sprint(val)))
|
||||
rvCount++
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
package mlflow
|
||||
|
||||
type Config struct {
|
||||
ApiURI string
|
||||
IgnoreSSL bool
|
||||
ApiURI string `json:"apiUrl,omitempty" yaml:"apiUri,omitempty"`
|
||||
IgnoreSSL bool `json:"ignoreSSL,omitempty" yaml:"ignoreSSL,omitempty"`
|
||||
}
|
||||
|
|
33
model.go
33
model.go
|
@ -2,6 +2,7 @@ package mlflow
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
|
@ -285,3 +286,35 @@ func GetModelVersionByAlias(conf Config, req GetModelVersionByAliasRequest) (*Mo
|
|||
|
||||
return &r.ModelVersion, nil
|
||||
}
|
||||
|
||||
func SearchModelVersionsAdvanced(conf Config, req SearchModelVersionsRequestAdvanced) ([]ModelVersion, error) {
|
||||
|
||||
var models []ModelVersion = []ModelVersion{}
|
||||
var j int = 0
|
||||
for {
|
||||
|
||||
resp, err := SearchModelVersions(conf, req.SearchModelVersionsRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, item := range resp.ModelVersions {
|
||||
if req.Stage == "" || req.Stage == item.CurrentStage {
|
||||
models = append(models, item)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if j > int(req.Limit) {
|
||||
fmt.Println("Limit exceed, terminating")
|
||||
break
|
||||
}
|
||||
if resp.NextPageToken == "" {
|
||||
break
|
||||
} else {
|
||||
req.PageToken = resp.NextPageToken
|
||||
}
|
||||
j++
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,11 @@
|
|||
package mlflow
|
||||
|
||||
import "encoding/json"
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
)
|
||||
|
||||
type ModelVersionStatus string
|
||||
|
||||
|
@ -46,8 +51,8 @@ type CreateRegisteredModelResponse struct {
|
|||
type ModelVersion struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Version string `json:"version,omitempty"`
|
||||
CreationTimestamp Timestamp `json:"creation_timestamp,omitempty"`
|
||||
LastUpdateTimestamp Timestamp `json:"last_update_timestamp,omitempty"`
|
||||
CreationTimestamp TimestampMs `json:"creation_timestamp,omitempty"`
|
||||
LastUpdateTimestamp TimestampMs `json:"last_update_timestamp,omitempty"`
|
||||
UserId string `json:"user_id,omitempty"`
|
||||
CurrentStage string `json:"current_stage,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
|
@ -60,6 +65,78 @@ type ModelVersion struct {
|
|||
Aliases []string `json:"aliases,omitempty"`
|
||||
}
|
||||
|
||||
func (ref *ModelVersion) GetTag(key string) *ModelVersionTag {
|
||||
for _, tag := range ref.Tags {
|
||||
if tag.Key == key {
|
||||
return &tag
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ref *ModelVersion) IsEqual(target *ModelVersion) bool {
|
||||
vRef := reflect.ValueOf(*ref)
|
||||
vTarget := reflect.ValueOf(*target)
|
||||
|
||||
typeOfV := vRef.Type()
|
||||
values := make([]interface{}, vRef.NumField())
|
||||
for i := 0; i < vRef.NumField(); i++ {
|
||||
values[i] = vRef.Field(i).Interface()
|
||||
refName := typeOfV.Field(i).Name
|
||||
refType := typeOfV.Field(i).Type.String()
|
||||
refVal := vRef.Field(i)
|
||||
|
||||
targetVal := reflect.Indirect(vTarget).FieldByName(refName)
|
||||
|
||||
switch refType {
|
||||
case "string", "mlflow.ModelVersionStatus":
|
||||
if refVal.String() != targetVal.String() {
|
||||
fmt.Println("NE STRING: ", refType, refName, refVal, targetVal, (refVal.String() == targetVal.String()))
|
||||
return false
|
||||
}
|
||||
case "mlflow.TimestampMs":
|
||||
if refVal.Int() != targetVal.Int() {
|
||||
fmt.Println("NE TimestampMs: ", refType, refName, refVal, targetVal, (refVal.Int() == targetVal.Int()))
|
||||
return false
|
||||
}
|
||||
case "[]mlflow.ModelVersionTag":
|
||||
sliceRef, _ := refVal.Interface().([]ModelVersionTag)
|
||||
sliceTarget, _ := targetVal.Interface().([]ModelVersionTag)
|
||||
|
||||
if len(sliceRef) != len(sliceTarget) {
|
||||
//fmt.Println("Not equal - len mismatch")
|
||||
return false
|
||||
}
|
||||
|
||||
for _, val := range sliceRef {
|
||||
if tag2 := target.GetTag(val.Key); tag2 == nil || tag2.Value != val.Value {
|
||||
return false
|
||||
}
|
||||
}
|
||||
case "[]string":
|
||||
sliceRef, _ := refVal.Interface().([]string)
|
||||
sliceTarget, _ := targetVal.Interface().([]string)
|
||||
sort.Strings(sliceRef)
|
||||
sort.Strings(sliceTarget)
|
||||
if len(sliceRef) != len(sliceTarget) {
|
||||
//fmt.Println("Not equal - len mismatch")
|
||||
return false
|
||||
}
|
||||
|
||||
for xI, xRef := range sliceRef {
|
||||
if xRef != sliceTarget[xI] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
default:
|
||||
fmt.Println("Unknown: ", refType, refName, refVal, targetVal)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
type RegisteredModel struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
CreationTimestamp Timestamp `json:"creation_timestamp,omitempty"`
|
||||
|
@ -138,9 +215,10 @@ type GetDownloadUriForModelVersionArtifactsResponse struct {
|
|||
}
|
||||
|
||||
type SearchModelVersionsRequest struct {
|
||||
// Valid keys are '{'version_number', 'source_path', 'run_id', 'name'}'
|
||||
Filter string `json:"filter,omitempty"`
|
||||
MaxResults int64 `json:"max_results,omitempty"`
|
||||
OrderBy []string `json:"order_by,omitempty"`
|
||||
PageSize int64 `json:"max_results,omitempty"`
|
||||
OrderBy string `json:"order_by,omitempty"`
|
||||
PageToken string `json:"page_token,omitempty"`
|
||||
}
|
||||
|
||||
|
@ -150,8 +228,8 @@ func (e *SearchModelVersionsRequest) serialize() []byte {
|
|||
}
|
||||
|
||||
func (e *SearchModelVersionsRequest) __init() *SearchModelVersionsRequest {
|
||||
if e.MaxResults == 0 {
|
||||
e.MaxResults = 1000
|
||||
if e.PageSize == 0 {
|
||||
e.PageSize = 1000
|
||||
}
|
||||
|
||||
return e
|
||||
|
@ -202,6 +280,7 @@ func (e *TransitionModelVersionStageRequest) serialize() []byte {
|
|||
type TransitionModelVersionStageResponse GetModelVersionResponse
|
||||
|
||||
type SearchRegisteredModelsRequest struct {
|
||||
// Valid keys are '{'name'}'
|
||||
Filter string `json:"filter,omitempty"`
|
||||
MaxResults int64 `json:"max_results,omitempty"`
|
||||
OrderBy string `json:"order_by,omitempty"`
|
||||
|
@ -300,3 +379,9 @@ func (e *DeleteRegisteredModelAliasRequest) serialize() []byte {
|
|||
}
|
||||
|
||||
type GetModelVersionByAliasResponse GetModelVersionResponse
|
||||
|
||||
type SearchModelVersionsRequestAdvanced struct {
|
||||
SearchModelVersionsRequest
|
||||
Stage string
|
||||
Limit int64
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue