mlflow-client-go/run.go

202 lines
4.7 KiB
Go
Raw Normal View History

2024-05-15 08:14:29 +00:00
package mlflow
import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
)
type (
	// SearchRunReply is the decoded body returned by the runs/search
	// endpoint. NextPageToken is empty when the server sends none.
	SearchRunReply struct {
		Runs          []Run  `json:"runs,omitempty"`
		NextPageToken string `json:"next_page_token,omitempty"`
	}

	// UpdateRunRequest is the JSON body sent to the runs/update endpoint.
	// Every field carries omitempty, so zero-valued fields are left out of
	// the encoded request.
	UpdateRunRequest struct {
		RunId   string      `json:"run_id,omitempty"`
		Status  RunStatus   `json:"status,omitempty"`
		EndTime TimestampMs `json:"end_time,omitempty"`
		RunName string      `json:"run_name,omitempty"`
	}

	// UpdateRunReply is the decoded body returned by the runs/update
	// endpoint. Note: omitempty has no effect on a struct-typed field in
	// encoding/json; it only matters if this type is ever re-encoded.
	UpdateRunReply struct {
		RunInfo RunInfo `json:"run_info,omitempty"`
	}
)
// CreateRun POSTs the serialized request to the MLflow runs/create endpoint
// and returns the Run decoded from the reply by apiReadReply.
// On any transport or reply error it returns (nil, err).
func CreateRun(conf Config, req CreateRunRequest) (*Run, error) {
	payload := bytes.NewReader([]byte(req.serialize()))
	resp, postErr := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/create", "application/json", payload)

	var reply CreateRunReply
	if err := apiReadReply(resp, postErr, &reply); err != nil {
		return nil, err
	}
	return &reply.Run, nil
}
// DeleteRun POSTs {"run_id": id} to the MLflow runs/delete endpoint.
// The reply body is not decoded; only the error from apiReadReply is returned.
//
// The body is produced by encoding/json rather than string concatenation, so
// an id containing quotes, backslashes or control characters is escaped and
// still yields valid JSON instead of a malformed or injected payload.
func DeleteRun(conf Config, id string) error {
	body, err := json.Marshal(map[string]string{"run_id": id})
	if err != nil {
		return err
	}
	resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/delete", "application/json", bytes.NewReader(body))
	return apiReadReply(resp, err, nil)
}
// RestoreRun POSTs {"run_id": id} to the MLflow runs/restore endpoint and,
// on success, returns the run as fetched by GetRun.
//
// The body is produced by encoding/json rather than string concatenation, so
// an id containing quotes, backslashes or control characters is escaped and
// still yields valid JSON.
func RestoreRun(conf Config, id string) (*Run, error) {
	body, err := json.Marshal(map[string]string{"run_id": id})
	if err != nil {
		return nil, err
	}
	resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/restore", "application/json", bytes.NewReader(body))
	if err = apiReadReply(resp, err, nil); err != nil {
		return nil, err
	}
	// The restore reply is not decoded (nil target above), so re-fetch the
	// run to hand the caller a populated *Run.
	return GetRun(conf, id)
}
// GetRun fetches a run with GET /api/2.0/mlflow/runs/get?run_id=<id> and
// returns the Run decoded from the reply by apiReadReply.
//
// The id is query-escaped, so characters such as '&', '#', '+' or spaces
// cannot truncate or alter the query string. For plain hex/alphanumeric ids
// the resulting URL is identical to the unescaped form.
func GetRun(conf Config, id string) (*Run, error) {
	query := url.Values{"run_id": {id}}.Encode()
	resp, err := http.Get(conf.ApiURI + "/api/2.0/mlflow/runs/get?" + query)

	var reply GetRunReply
	if err = apiReadReply(resp, err, &reply); err != nil {
		return nil, err
	}
	return &reply.Run, nil
}
// SearchRuns POSTs the body produced by opts.getReader() to the MLflow
// runs/search endpoint. It returns the decoded reply (runs plus an optional
// next-page token) or (nil, err) on failure.
func SearchRuns(conf Config, opts SearchRunRequest) (*SearchRunReply, error) {
	endpoint := conf.ApiURI + "/api/2.0/mlflow/runs/search"
	resp, postErr := http.Post(endpoint, "application/json", opts.getReader())

	reply := new(SearchRunReply)
	if err := apiReadReply(resp, postErr, reply); err != nil {
		return nil, err
	}
	return reply, nil
}
// UpdateRun JSON-encodes opts, POSTs it to the MLflow runs/update endpoint
// and returns the RunInfo decoded from the reply. A marshal, transport or
// reply error yields (nil, err).
func UpdateRun(conf Config, opts UpdateRunRequest) (*RunInfo, error) {
	body, err := json.Marshal(opts)
	if err != nil {
		return nil, err
	}

	var reply UpdateRunReply
	endpoint := conf.ApiURI + "/api/2.0/mlflow/runs/update"
	resp, postErr := http.Post(endpoint, "application/json", bytes.NewReader(body))
	if err := apiReadReply(resp, postErr, &reply); err != nil {
		return nil, err
	}
	return &reply.RunInfo, nil
}
// LogMetric JSON-encodes the request returned by opts.__init() and POSTs it
// to the MLflow runs/log-metric endpoint. The reply body is not decoded.
//
// NOTE(review): opts.__init() presumably fills defaults (e.g. a timestamp)
// before encoding — its body is not visible here; confirm.
//
// Earlier versions printed opts and the encoded body to stdout on every call;
// a client library must not write to the host program's stdout, so that debug
// output is gone.
func LogMetric(conf Config, opts LogMetricRequest) error {
	args, err := json.Marshal(opts.__init())
	if err != nil {
		return fmt.Errorf("encoding log-metric request: %w", err)
	}
	resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/log-metric", "application/json", bytes.NewReader(args))
	return apiReadReply(resp, err, nil)
}
// LogBatch runs each entry of opts.Metrics through __init(), JSON-encodes the
// whole request and POSTs it to the MLflow runs/log-batch endpoint. The reply
// body is not decoded.
//
// opts is received by value, but its Metrics slice still shares the caller's
// backing array. Rewriting elements in place would silently modify the
// caller's data and race with any concurrent reader, so the metrics are
// normalized in a private copy instead.
func LogBatch(conf Config, opts LogBatchRequest) error {
	// Type-agnostic clone: a nil slice stays nil, so the encoded JSON is unchanged.
	metrics := append(opts.Metrics[:0:0], opts.Metrics...)
	for i := range metrics {
		metrics[i] = *metrics[i].__init()
	}
	opts.Metrics = metrics

	args, err := json.Marshal(opts)
	if err != nil {
		return err
	}
	resp, err := http.Post(conf.ApiURI+"/api/2.0/mlflow/runs/log-batch", "application/json", bytes.NewReader(args))
	return apiReadReply(resp, err, nil)
}
// LogInputs JSON-encodes opts and POSTs it to the MLflow runs/log-inputs
// endpoint. The reply body is not decoded; the result of apiReadReply is
// returned as-is.
func LogInputs(conf Config, opts LogInputsRequest) error {
	payload, err := json.Marshal(opts)
	if err != nil {
		return err
	}
	endpoint := conf.ApiURI + "/api/2.0/mlflow/runs/log-inputs"
	resp, postErr := http.Post(endpoint, "application/json", bytes.NewReader(payload))
	return apiReadReply(resp, postErr, nil)
}
// SetTag JSON-encodes opts and POSTs it to the MLflow runs/set-tag endpoint.
// The reply body is not decoded; the result of apiReadReply is returned as-is.
func SetTag(conf Config, opts SetTagRequest) error {
	payload, err := json.Marshal(opts)
	if err != nil {
		return err
	}
	endpoint := conf.ApiURI + "/api/2.0/mlflow/runs/set-tag"
	resp, postErr := http.Post(endpoint, "application/json", bytes.NewReader(payload))
	return apiReadReply(resp, postErr, nil)
}
// DeleteTag JSON-encodes opts and POSTs it to the MLflow runs/delete-tag
// endpoint. The reply body is not decoded; the result of apiReadReply is
// returned as-is.
func DeleteTag(conf Config, opts DeleteTagRequest) error {
	payload, err := json.Marshal(opts)
	if err != nil {
		return err
	}
	endpoint := conf.ApiURI + "/api/2.0/mlflow/runs/delete-tag"
	resp, postErr := http.Post(endpoint, "application/json", bytes.NewReader(payload))
	return apiReadReply(resp, postErr, nil)
}
// LogParam JSON-encodes opts and POSTs it to the MLflow runs/log-parameter
// endpoint. The reply body is not decoded; the result of apiReadReply is
// returned as-is.
func LogParam(conf Config, opts LogParamRequest) error {
	payload, err := json.Marshal(opts)
	if err != nil {
		return err
	}
	endpoint := conf.ApiURI + "/api/2.0/mlflow/runs/log-parameter"
	resp, postErr := http.Post(endpoint, "application/json", bytes.NewReader(payload))
	return apiReadReply(resp, postErr, nil)
}
// GetMetricHistory issues a GET to the MLflow metrics/get-history endpoint.
// The query string is produced by UrlEncode from opts.__init(), and the
// decoded response is returned, or (nil, err) on failure.
//
// NOTE(review): UrlEncode is assumed to return the query including its
// leading '?', since it is appended directly to the path — confirm.
func GetMetricHistory(conf Config, opts GetMetricHistoryRequest) (*GetMetricHistoryResponse, error) {
	var history GetMetricHistoryResponse
	query := UrlEncode(*opts.__init())
	resp, getErr := http.Get(conf.ApiURI + "/api/2.0/mlflow/metrics/get-history" + query)
	if err := apiReadReply(resp, getErr, &history); err != nil {
		return nil, err
	}
	return &history, nil
}