package okrrpclogic

import (
	"context"
	"fmt"
	"oa-server/app/oacenter/model/okr"
	"oa-server/app/oacenter/oa_rpc/internal/svc"
	"oa-server/app/oacenter/oa_rpc/oa"

	"github.com/shopspring/decimal"
	"github.com/zeromicro/go-zero/core/logx"
)

// OkrApplyAckLogic holds the per-call state used by OkrApplyAck.
type OkrApplyAckLogic struct {
	// ctx is the context supplied to NewOkrApplyAckLogic; it is forwarded to
	// the weight check and to the period update.
	ctx context.Context
	// svcCtx gives access to the models used by this logic (OkrPeriodModel,
	// OkrTaskModel, OkrAlignmentModel).
	svcCtx *svc.ServiceContext
	// Logger is embedded so its logging methods are available directly on
	// the logic value; NewOkrApplyAckLogic builds it from ctx.
	logx.Logger
}

// NewOkrApplyAckLogic returns an OkrApplyAckLogic that keeps ctx and svcCtx
// and embeds a logger derived from ctx.
func NewOkrApplyAckLogic(ctx context.Context, svcCtx *svc.ServiceContext) *OkrApplyAckLogic {
	ctxLogger := logx.WithContext(ctx)

	return &OkrApplyAckLogic{Logger: ctxLogger, svcCtx: svcCtx, ctx: ctx}
}

// OkrApplyAck moves the OKR period identified by in.OkrId into the
// OKR_STATUS_ACKING approval status.
//
// Steps, in order:
//  1. checkOkrStatus (defined elsewhere in this package) returns the period
//     or an error, which is passed through unchanged.
//  2. The caller in.ApplyAckBy must be the period's Owner, otherwise
//     ErrPermissionDenied is returned.
//  3. checkWeight must accept the weights stored under the period.
//  4. ApprovalStatus is set and the period is persisted with Update.
//
// On success an empty response is returned; on any failure the response is
// nil and the error of the failing step is returned as-is.
func (l *OkrApplyAckLogic) OkrApplyAck(in *oa.OkrApplyAckReq) (*oa.EmptyResponse, error) {
	okrPeriod, err := checkOkrStatus(l.svcCtx.OkrPeriodModel, int64(in.OkrId))
	if err != nil {
		return nil, err
	}
	if okrPeriod.Owner != in.ApplyAckBy {
		return nil, ErrPermissionDenied
	}

	// Weight validation.
	if err := checkWeight(l.ctx, l.svcCtx.OkrTaskModel, l.svcCtx.OkrAlignmentModel, int64(in.OkrId)); err != nil {
		return nil, err
	}

	okrPeriod.ApprovalStatus = int64(oa.OkrStatus_OKR_STATUS_ACKING)
	if err := l.svcCtx.OkrPeriodModel.Update(l.ctx, okrPeriod); err != nil {
		// Log through the embedded logger (built from l.ctx) rather than the
		// package-level logx.Errorw, so the entry keeps the request context.
		l.Errorw("更新okr状态失败", logx.Field("err", err), logx.Field("okr_id", in.OkrId))
		return nil, err
	}

	return &oa.EmptyResponse{}, nil
}

// checkWeight validates the weights of everything stored under the OKR period
// okrPeriodId.
//
// Rules enforced:
//   - the weights of all objectives (O) add up to 100;
//   - for every O that has key results (KR), the KR weights add up to 100;
//   - for every self-created KR that has tasks, the task weights add up to 100;
//   - an O without any KR and a KR without any task are accepted, but the
//     period as a whole must have at least one KR;
//   - no individual weight may be 0.
//
// KRs and tasks are either self-created (rows returned by taskModel) or
// aligned over from other people (rows returned by okrAlignmentModel); both
// kinds count towards the total of their parent. Tasks below an aligned KR are
// not inspected by this function.
//
// Every sum is truncated to two decimal places before it is compared with
// HundredPercent. Objectives are checked in the order in which they first
// appear in the task list, so the reported violation is the same on every run.
//
// It returns nil when all rules hold, the unchanged model error when a query
// fails, or an error describing the first violated rule.
func checkWeight(ctx context.Context, taskModel okr.XOKrTaskModel, okrAlignmentModel okr.XOkrAlignmentModel, okrPeriodId int64) error {
	// Bind the logger to ctx so log entries keep the request context.
	logger := logx.WithContext(ctx)

	taskList, err := taskModel.GetByOkrPeriodId(okrPeriodId)
	if err != nil {
		logger.Errorw("获取任务列表失败", logx.Field("err", err), logx.Field("okr_id", okrPeriodId))
		return err
	}

	oMap := make(map[int64]*okr.XOKrTask)      // O id -> O
	krMap := make(map[int64][]*okr.XOKrTask)   // O id -> self-created KRs
	taskMap := make(map[int64][]*okr.XOKrTask) // KR id -> self-created tasks
	var oIdList, krIdList []int64
	// oOrder lists each O id once, in first-seen order. Go randomizes map
	// iteration, so ranging over oMap would make the reported error vary.
	var oOrder []int64

	// Split the rows into objectives, key results and tasks.
	for _, task := range taskList {
		switch task.EntityType {
		case OkrEntityTypeObjective:
			if _, seen := oMap[task.EntityId]; !seen {
				oOrder = append(oOrder, task.EntityId)
			}
			oMap[task.EntityId] = task
			oIdList = append(oIdList, task.EntityId)
		case OkrEntityTypeKeyresult:
			krMap[task.ParentEntityId] = append(krMap[task.ParentEntityId], task)
			krIdList = append(krIdList, task.EntityId)
		case OkrEntityTypeTask:
			taskMap[task.ParentEntityId] = append(taskMap[task.ParentEntityId], task)
		}
	}

	// The weights of all objectives must add up to 100.
	objectives := make([]*okr.XOKrTask, 0, len(oOrder))
	for _, oId := range oOrder {
		objectives = append(objectives, oMap[oId])
	}
	totalOWeight, err := sumOwnEntityWeights("O", objectives)
	if err != nil {
		return err
	}
	if w, ok := truncatedWeightTotal(totalOWeight); !ok {
		return fmt.Errorf("所有O的权重和应为100, 当前为%.2f", w)
	}

	// Load the KRs aligned to these objectives.
	alignedKRList, err := okrAlignmentModel.GetAlignedChildren(ctx, oIdList)
	if err != nil {
		logger.Errorw("failed to get aligned keyresult list from x_okr_alignment on checkWeight", logx.Field("error", err), logx.Field("okr_id", okrPeriodId))
		return err
	}
	alignedKRMap := alignedKRList.ByAlignWithEntityId()
	// Aligned KRs and self-created KRs must not both be empty.
	if len(alignedKRList) == 0 && len(krIdList) == 0 {
		return fmt.Errorf("KR不能为空")
	}

	// Load the tasks aligned to the self-created KRs.
	alignedTaskList, err := okrAlignmentModel.GetAlignedChildren(ctx, krIdList)
	if err != nil {
		logger.Errorw("failed to get aligned task list from x_okr_alignment on checkWeight", logx.Field("error", err), logx.Field("okr_id", okrPeriodId))
		return err
	}
	alignedTaskMap := alignedTaskList.ByAlignWithEntityId()

	// Below every O the KR weights must add up to 100.
	for _, oId := range oOrder {
		o := oMap[oId]
		krs := krMap[oId]
		alignedKRs := alignedKRMap[oId]
		if len(krs) == 0 && len(alignedKRs) == 0 {
			// Neither self-created nor aligned KRs: nothing to check.
			continue
		}

		totalKRWeight, sumErr := sumOwnEntityWeights("KR", krs)
		if sumErr != nil {
			return sumErr
		}
		for _, v := range alignedKRs {
			if v.AlignWithWeight == 0 {
				return fmt.Errorf("KR的权重不能为0")
			}
			totalKRWeight = totalKRWeight.Add(decimal.NewFromFloat(v.AlignWithWeight))
		}
		if w, ok := truncatedWeightTotal(totalKRWeight); !ok {
			return fmt.Errorf("O[%s]下的KR权重和应为100, 当前为%.2f", o.Content, w)
		}

		// Below every self-created KR the task weights must add up to 100.
		for _, kr := range krs {
			tasks := taskMap[kr.EntityId]
			alignedTasks := alignedTaskMap[kr.EntityId]
			if len(tasks) == 0 && len(alignedTasks) == 0 {
				// Neither self-created nor aligned tasks: nothing to check.
				continue
			}

			totalTaskWeight, sumErr := sumOwnEntityWeights("Task", tasks)
			if sumErr != nil {
				return sumErr
			}
			for _, v := range alignedTasks {
				if v.AlignWithWeight == 0 {
					return fmt.Errorf("Task的权重不能为0")
				}
				totalTaskWeight = totalTaskWeight.Add(decimal.NewFromFloat(v.AlignWithWeight))
			}
			if w, ok := truncatedWeightTotal(totalTaskWeight); !ok {
				return fmt.Errorf("KR[%s]下的任务权重和应为100, 当前为%.2f", kr.Content, w)
			}
		}
	}

	return nil
}

// sumOwnEntityWeights adds up the Weight of the given self-created entries.
// It stops at the first entry whose Weight is 0 and returns an error naming
// that entry, prefixed with label ("O", "KR" or "Task").
func sumOwnEntityWeights(label string, items []*okr.XOKrTask) (decimal.Decimal, error) {
	total := decimal.Zero
	for _, item := range items {
		if item.Weight == 0 {
			return decimal.Zero, fmt.Errorf("%s[%s]的权重不能为0", label, item.Content)
		}
		total = total.Add(decimal.NewFromFloat(item.Weight))
	}
	return total, nil
}

// truncatedWeightTotal truncates total to two decimal places and reports
// whether the result equals HundredPercent. The truncated value is returned
// as a float64 for use in error messages.
func truncatedWeightTotal(total decimal.Decimal) (float64, bool) {
	total = total.Truncate(2)
	// The exactness flag of Float64 is irrelevant: the value is only printed
	// with two decimals.
	w, _ := total.Float64()
	return w, total.Equal(HundredPercent)
}
