package okrrpclogic

import (
	"context"
	"time"

	"oa-server/app/oacenter/model/okr"
	"oa-server/app/oacenter/oa_rpc/internal/svc"
	"oa-server/app/oacenter/oa_rpc/oa"

	"github.com/zeromicro/go-zero/core/logx"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// ListOkrPeriodLogic bundles the context, service context and logger used to
// serve the ListOkrPeriod call.
type ListOkrPeriodLogic struct {
	// ctx is the context handed to the constructor; it is passed on to the
	// model query.
	ctx    context.Context
	// svcCtx gives access to shared dependencies such as OkrPeriodModel.
	svcCtx *svc.ServiceContext
	// Logger is embedded so logging methods can be called directly on the
	// logic value.
	logx.Logger
}

// NewListOkrPeriodLogic builds a ListOkrPeriodLogic that holds the given
// context and service context, with a logger derived from that same context.
func NewListOkrPeriodLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ListOkrPeriodLogic {
	logic := new(ListOkrPeriodLogic)
	logic.ctx = ctx
	logic.svcCtx = svcCtx
	logic.Logger = logx.WithContext(ctx)
	return logic
}

// ListOkrPeriod returns the list of OKR periods for the requested users and
// date range.
//
// The request is checked with Validate first, and a validation error is
// returned to the caller unchanged. The email list, date bounds and paging
// parameters are then passed to OkrPeriodModel.FindOkrPeriodByUsersAndTime.
// If that query fails, the error is logged and ErrOkrPeriodQueryFailed is
// returned instead. On success the response carries the count returned by the
// model and one converted entry per returned row.
func (l *ListOkrPeriodLogic) ListOkrPeriod(in *oa.ListOkrPeriodReq) (*oa.ListOkrPeriodResp, error) {
	startDate, endDate, err := l.Validate(in)
	if err != nil {
		return nil, err
	}

	// The bounds are calendar dates, so format them as parsed. Converting a
	// midnight-UTC value to the local zone first would move it to the previous
	// day on any host whose zone is behind UTC.
	req := okr.FindOkrPeriodByUsersAndTimeQuery{
		Emails:    in.GetEmailList(),
		StartDate: startDate.Format(time.DateOnly),
		EndDate:   endDate.Format(time.DateOnly),
		PageNum:   in.GetPageNum(),
		PageSize:  in.GetPageSize(),
	}
	cnt, okrPeriodList, err := l.svcCtx.OkrPeriodModel.FindOkrPeriodByUsersAndTime(l.ctx, req)
	if err != nil {
		// Log through the embedded logger so the entry is tied to the request
		// context; the caller only sees the sentinel error.
		l.Errorw("failed to query x_okr_period on ListOkrPeriod", logx.Field("error", err))
		return nil, ErrOkrPeriodQueryFailed
	}

	list := make([]*oa.OkrPeriod, len(okrPeriodList))
	for i, v := range okrPeriodList {
		// NOTE(review): stored timestamps are rendered in the server's local
		// zone; confirm this matches how the model populates these fields.
		list[i] = &oa.OkrPeriod{
			PeriodId:       v.PeriodId,
			Owner:          v.Owner,
			StartDate:      v.StartDate.Local().Format(time.DateOnly),
			EndDate:        v.EndDate.Local().Format(time.DateOnly),
			ApprovalStatus: oa.OkrStatus(v.ApprovalStatus),
			CreatedAt:      v.CreatedAt.Local().Format(time.RFC3339),
			UpdatedAt:      v.UpdatedAt.Local().Format(time.RFC3339),
		}
	}

	return &oa.ListOkrPeriodResp{Total: cnt, List: list}, nil
}

// Validate checks the request parameters and parses the requested date range.
//
// The email list must be non-empty, and both start_date and end_date must be
// in time.DateOnly form (YYYY-MM-DD). On success it returns the parsed dates;
// time.Parse yields them at midnight UTC. On failure it returns zero times and
// a gRPC InvalidArgument status naming the offending parameter.
func (l *ListOkrPeriodLogic) Validate(in *oa.ListOkrPeriodReq) (startDate, endDate time.Time, err error) {
	if len(in.GetEmailList()) == 0 {
		return time.Time{}, time.Time{}, status.Error(codes.InvalidArgument, "email_list参数无效")
	}

	if startDate, err = time.Parse(time.DateOnly, in.GetStartDate()); err != nil {
		return time.Time{}, time.Time{}, status.Error(codes.InvalidArgument, "start_date参数无效")
	}
	if endDate, err = time.Parse(time.DateOnly, in.GetEndDate()); err != nil {
		// Return zero times here too, so no half-parsed range reaches callers.
		return time.Time{}, time.Time{}, status.Error(codes.InvalidArgument, "end_date参数无效")
	}

	return startDate, endDate, nil
}
