From c7aa20f7f1a161b252d54c84bbf18da35eae01a2 Mon Sep 17 00:00:00 2001 From: Pavel Dmitriev Date: Tue, 21 May 2024 17:48:56 +0300 Subject: [PATCH] added func (ref *ModelVersion) IsEqual(target *ModelVersion) bool --- model_types.go | 79 +++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 78 insertions(+), 1 deletion(-) diff --git a/model_types.go b/model_types.go index 3ba8f4c..d7f6c33 100644 --- a/model_types.go +++ b/model_types.go @@ -1,6 +1,11 @@ package mlflow -import "encoding/json" +import ( + "encoding/json" + "fmt" + "reflect" + "sort" +) type ModelVersionStatus string @@ -60,6 +65,78 @@ type ModelVersion struct { Aliases []string `json:"aliases,omitempty"` } +func (ref *ModelVersion) GetTag(key string) *ModelVersionTag { + for _, tag := range ref.Tags { + if tag.Key == key { + return &tag + } + } + return nil +} + +func (ref *ModelVersion) IsEqual(target *ModelVersion) bool { + vRef := reflect.ValueOf(*ref) + vTarget := reflect.ValueOf(*target) + + typeOfV := vRef.Type() + values := make([]interface{}, vRef.NumField()) + for i := 0; i < vRef.NumField(); i++ { + values[i] = vRef.Field(i).Interface() + refName := typeOfV.Field(i).Name + refType := typeOfV.Field(i).Type.String() + refVal := vRef.Field(i) + + targetVal := reflect.Indirect(vTarget).FieldByName(refName) + + switch refType { + case "string", "mlflow.ModelVersionStatus": + if refVal.String() != targetVal.String() { + fmt.Println("NE STRING: ", refType, refName, refVal, targetVal, (refVal.String() == targetVal.String())) + return false + } + case "mlflow.TimestampMs": + if refVal.Int() != targetVal.Int() { + fmt.Println("NE TimestampMs: ", refType, refName, refVal, targetVal, (refVal.Int() == targetVal.Int())) + return false + } + case "[]mlflow.ModelVersionTag": + sliceRef, _ := refVal.Interface().([]ModelVersionTag) + sliceTarget, _ := targetVal.Interface().([]ModelVersionTag) + + if len(sliceRef) != len(sliceTarget) { + //fmt.Println("Not equal - len mismatch") + return false + } + + for _, val := range sliceRef { + if tag2 := target.GetTag(val.Key); tag2 == nil || tag2.Value != val.Value { + return false + } + } + case "[]string": + sliceRef, _ := refVal.Interface().([]string) + sliceTarget, _ := targetVal.Interface().([]string) + sort.Strings(sliceRef) + sort.Strings(sliceTarget) + if len(sliceRef) != len(sliceTarget) { + //fmt.Println("Not equal - len mismatch") + return false + } + + for xI, xRef := range sliceRef { + if xRef != sliceTarget[xI] { + return false + } + } + default: + fmt.Println("Unknown: ", refType, refName, refVal, targetVal) + } + + } + + return true +} + type RegisteredModel struct { Name string `json:"name,omitempty"` CreationTimestamp Timestamp `json:"creation_timestamp,omitempty"`