package reportrpclogic

import (
	"context"
	"errors"

	"oa-server/app/oacenter/model/report"
	"oa-server/app/oacenter/oa_rpc/internal/logic/common"
	"oa-server/app/oacenter/oa_rpc/internal/svc"
	"oa-server/app/oacenter/oa_rpc/oa"

	"github.com/zeromicro/go-zero/core/logx"
	"github.com/zeromicro/go-zero/core/stores/sqlx"
)

// UpdateReportLogic implements the UpdateReport RPC: it validates the request,
// checks that the caller may edit the report, and rewrites the report together
// with its items and mentions.
type UpdateReportLogic struct {
	// Logger is created from ctx via logx.WithContext in NewUpdateReportLogic,
	// so it is scoped to the current request.
	logx.Logger

	ctx    context.Context     // request context handed to every model call
	svcCtx *svc.ServiceContext // service dependencies (ReportModel, ReportItemModel, ReportMentionModel, ...)
}

// NewUpdateReportLogic returns an UpdateReportLogic bound to ctx and svcCtx.
// Its embedded logger is derived from ctx with logx.WithContext.
func NewUpdateReportLogic(ctx context.Context, svcCtx *svc.ServiceContext) *UpdateReportLogic {
	logic := new(UpdateReportLogic)
	logic.ctx = ctx
	logic.svcCtx = svcCtx
	logic.Logger = logx.WithContext(ctx)
	return logic
}

// UpdateReport updates an existing work report on behalf of the requesting
// user. It also replaces the report's items and mentions with those in the
// request.
//
// Steps:
//  1. Validate the request (see Validate).
//  2. Load the report by id. A missing row yields ErrReportNotFound. Any other
//     lookup failure is logged and yields ErrReportQueryFailed.
//  3. Return common.ErrPermissionDenied unless rpt.CanUpdate(userEmail) allows
//     the caller to edit the report.
//  4. Inside one ReportModel.TransCtx call:
//     - update the x_report row;
//     - delete the report's x_report_item rows;
//     - insert one item per ReportItemList entry;
//     - delete the report's x_report_mention rows;
//     - insert one mention per (item, mentioned email) pair.
//     Any error returned from that callback is reported to the caller as
//     ErrReportUpdateFailed. (Rollback-on-error is presumably provided by
//     TransCtx; confirm in the model package.)
func (l *UpdateReportLogic) UpdateReport(in *oa.UpdateReportReq) (*oa.UpdateReportResp, error) {
	if err := l.Validate(in); err != nil {
		return nil, err
	}

	rpt, err := l.svcCtx.ReportModel.FindOne(l.ctx, in.GetReportId())
	if err != nil {
		if errors.Is(err, report.ErrNotFound) {
			return nil, ErrReportNotFound
		}
		// Log through the embedded l.Logger (built with logx.WithContext)
		// instead of the package-level logx functions so the entry keeps the
		// request context.
		l.Errorw("failed to query x_report on UpdateReport.ReportModel.FindOne",
			logx.Field("error", err), logx.Field("report_id", in.GetReportId()))
		return nil, ErrReportQueryFailed
	}
	if !rpt.CanUpdate(in.GetUserEmail()) {
		return nil, common.ErrPermissionDenied
	}

	// Copy the editable header fields from the request onto the loaded row.
	rpt.IsDelayed = convertBoolToInt64(in.GetIsDelayed())
	rpt.DelayReason = in.GetDelayReason()
	rpt.HasRisk = convertBoolToInt64(in.GetHasRisk())
	rpt.RiskDesc = in.GetRiskDesc()
	rpt.Status = uint64(convertReportStatus(in.GetIsDraft()))

	// Update the report and replace its items and mentions in one transaction.
	err = l.svcCtx.ReportModel.TransCtx(l.ctx, func(ctx context.Context, s sqlx.Session) error {
		// Update x_report.
		if err := l.svcCtx.ReportModel.TransUpdateCtx(ctx, s, rpt); err != nil {
			l.Errorw("failed to update x_report on UpdateReport.ReportModel.TransUpdateCtx",
				logx.Field("error", err), logx.Field("report_id", rpt.Id))
			return err
		}

		// Remove the existing x_report_item rows before re-inserting them.
		if err := l.svcCtx.ReportItemModel.TransDeleteByReportIdCtx(ctx, s, rpt.Id); err != nil {
			l.Errorw("failed to delete x_report_item on UpdateReport.ReportItemModel.TransDeleteByReportIdCtx",
				logx.Field("error", err), logx.Field("report_id", rpt.Id))
			return err
		}

		// Insert the new x_report_item rows. Mentions must reference the newly
		// generated item ids, so they are collected here and written afterwards.
		var mentions []*report.XReportMention
		for _, item := range in.GetReportItemList() {
			// Use the nil-safe generated getters, as is done for `in`.
			result, err := l.svcCtx.ReportItemModel.TransInsertCtx(ctx, s, &report.XReportItem{
				ReportId: rpt.Id,
				Content:  item.GetContent(),
			})
			if err != nil {
				l.Errorw("failed to create x_report_item on UpdateReport.ReportItemModel.TransInsertCtx",
					logx.Field("error", err), logx.Field("report_id", rpt.Id))
				return err
			}
			itemID, err := result.LastInsertId()
			if err != nil {
				l.Errorw("failed to get x_report_item LastInsertId on UpdateReport.ReportItemModel.TransInsertCtx",
					logx.Field("error", err), logx.Field("report_id", rpt.Id))
				return err
			}

			for _, email := range item.GetMentionList() {
				mentions = append(mentions, &report.XReportMention{
					ReportId:         rpt.Id,
					ReportType:       rpt.ReportType,
					ReportDate:       rpt.ReportDate,
					ReportItemId:     uint64(itemID),
					InitiatorEmail:   in.GetUserEmail(),
					MentionUserEmail: email,
				})
			}
		}

		// Remove the existing x_report_mention rows before re-inserting them.
		if err := l.svcCtx.ReportMentionModel.TransDeleteByReportIdCtx(ctx, s, rpt.Id); err != nil {
			l.Errorw("failed to delete x_report_mention on UpdateReport.ReportMentionModel.TransDeleteByReportIdCtx",
				logx.Field("error", err), logx.Field("report_id", rpt.Id))
			return err
		}

		// Insert the new x_report_mention rows. Their generated ids are never
		// used, so the insert result is not inspected further.
		for _, m := range mentions {
			if _, err := l.svcCtx.ReportMentionModel.TransInsertCtx(ctx, s, m); err != nil {
				l.Errorw("failed to create x_report_mention on UpdateReport.ReportMentionModel.TransInsertCtx",
					logx.Field("error", err), logx.Field("report_id", rpt.Id))
				return err
			}
		}
		return nil
	})
	if err != nil {
		// The concrete cause was already logged inside the transaction callback.
		return nil, ErrReportUpdateFailed
	}

	return &oa.UpdateReportResp{}, nil
}

// Validate checks the fields UpdateReport needs before touching the database.
// It returns ErrUserEmailRequired when the user email is empty and
// ErrReportIdRequired when the report id is zero. The email is checked first.
// It returns nil when both are present.
func (l *UpdateReportLogic) Validate(in *oa.UpdateReportReq) error {
	switch {
	case in.GetUserEmail() == "":
		return ErrUserEmailRequired
	case in.GetReportId() == 0:
		return ErrReportIdRequired
	default:
		return nil
	}
}
