package reportrpclogic

import (
	"context"
	"time"

	"oa-server/app/oacenter/model/report"
	"oa-server/app/oacenter/oa_rpc/internal/svc"
	"oa-server/app/oacenter/oa_rpc/oa"

	"github.com/zeromicro/go-zero/core/logx"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// ListReportLogic holds the per-request state for the ListReport RPC, which
// lists a user's work reports together with their items and comments.
type ListReportLogic struct {
	// ctx is the request context; it is passed to every model query.
	ctx context.Context
	// svcCtx provides the ReportModel, ReportItemModel and ReportCommentModel
	// used by ListReport.
	svcCtx *svc.ServiceContext
	// Logger is bound to ctx by NewListReportLogic via logx.WithContext.
	logx.Logger
}

// NewListReportLogic builds a ListReportLogic for a single request. The
// returned value keeps ctx for model calls and carries a logger derived from
// ctx, so log lines emitted through it are tied to this request.
func NewListReportLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ListReportLogic {
	logic := new(ListReportLogic)
	logic.ctx = ctx
	logic.svcCtx = svcCtx
	logic.Logger = logx.WithContext(ctx)
	return logic
}

// ListReport returns one page of work reports for in.TargetUserEmail whose
// report date falls within [ReportStartDate, ReportEndDate]. Each report
// includes its items and comments.
//
// When the caller (in.UserEmail) views someone else's reports, only reports
// with status ReportStatus_DONE_RS are returned. When the caller views their
// own reports, the status filter is cleared (ReportStatus = 0) so that reports
// of every status are listed.
//
// Errors:
//   - any error returned by Validate (invalid or missing arguments);
//   - ErrReportQueryFailed, ErrReportItemQueryFailed or
//     ErrReportCommentQueryFailed when the corresponding model query fails.
//
// Total is the count reported by FindUserReports. It is returned even when the
// requested page contains no rows, so clients can still paginate.
func (l *ListReportLogic) ListReport(in *oa.ListReportReq) (*oa.ListReportResp, error) {
	// Validate also normalizes in.PageNum / in.PageSize, so it must run
	// before the query request reads them below.
	startDate, endDate, err := l.Validate(in)
	if err != nil {
		return nil, err
	}
	queryReq := report.FindUserReportsReq{
		UserEmail:       in.GetTargetUserEmail(),
		ReportStartDate: startDate,
		ReportEndDate:   endDate,
		ReportStatus:    uint64(oa.ReportStatus_DONE_RS),
		PageNum:         in.GetPageNum(),
		PageSize:        in.GetPageSize(),
	}
	if in.GetUserEmail() == in.GetTargetUserEmail() {
		// Viewing one's own reports: list reports of every status.
		queryReq.ReportStatus = 0
	}
	total, reports, err := l.svcCtx.ReportModel.FindUserReports(l.ctx, queryReq)
	if err != nil {
		// Log through the context-bound logger so the entry carries the request trace.
		l.Errorw("failed to query x_report on ListReport.ReportModel.FindUserReports", logx.Field("error", err))
		return nil, ErrReportQueryFailed
	}
	if len(reports) == 0 {
		// A page past the end is still a valid query; keep total so clients
		// can tell "no such page" apart from "no reports at all".
		return &oa.ListReportResp{Total: total}, nil
	}

	reportIds := reports.Ids()

	reportItems, err := l.svcCtx.ReportItemModel.FindByReportIds(l.ctx, reportIds)
	if err != nil {
		l.Errorw("failed to query x_report_item on ListReport.ReportItemModel.FindByReportIds", logx.Field("error", err))
		return nil, ErrReportItemQueryFailed
	}

	reportComments, err := l.svcCtx.ReportCommentModel.FindByReportIds(l.ctx, reportIds)
	if err != nil {
		l.Errorw("failed to query x_report_comment on ListReport.ReportCommentModel.FindByReportIds", logx.Field("error", err))
		return nil, ErrReportCommentQueryFailed
	}

	// Group children by parent report once so each report is assembled with
	// map lookups rather than rescanning the item/comment lists.
	itemMap := reportItems.ByReportId()
	commentMap := reportComments.ByReportId()

	reportList := make([]*oa.Report, len(reports))
	for i, v := range reports {
		itemList := itemMap[v.Id]
		items := make([]*oa.ReportItem, len(itemList))
		for j, m := range itemList {
			items[j] = &oa.ReportItem{
				ReportItemId: m.Id,
				Content:      m.Content,
			}
		}

		commentList := commentMap[v.Id]
		comments := make([]*oa.ReportComment, len(commentList))
		for j, m := range commentList {
			comments[j] = &oa.ReportComment{
				ReportCommentId: m.Id,
				UserEmail:       m.UserEmail,
				Content:         m.Content,
				CreatedAt:       m.CreatedAt.Local().Format(time.DateTime),
				UpdatedAt:       m.UpdatedAt.Local().Format(time.DateTime),
			}
		}

		reportList[i] = &oa.Report{
			ReportId:          v.Id,
			ReportType:        oa.ReportType(v.ReportType),
			ReportDate:        v.ReportDate.Local().Format(time.DateOnly),
			IsDelayed:         convertIn64ToBool(v.IsDelayed),
			DelayReason:       v.DelayReason,
			HasRisk:           convertIn64ToBool(v.HasRisk),
			RiskDesc:          v.RiskDesc,
			ReportStatus:      oa.ReportStatus(v.Status),
			CreatedAt:         v.CreatedAt.Local().Format(time.DateTime),
			UpdatedAt:         v.UpdatedAt.Local().Format(time.DateTime),
			ReportItemList:    items,
			ReportCommentList: comments,
		}
	}

	return &oa.ListReportResp{Total: total, List: reportList}, nil
}

// Validate checks a ListReportReq and parses its date range.
//
// Both target_user_email and user_email are required. report_start_date and
// report_end_date must be time.DateOnly strings ("2006-01-02"); they are
// parsed in time.Local. A range whose end precedes its start is rejected.
//
// Validate also mutates in to normalize paging: a non-positive PageNum becomes
// 1, and a PageSize that is non-positive or above ReportMaxPageSize becomes
// ReportDefaultPageSize.
//
// On error, startDate/endDate hold whatever had been parsed so far; callers
// should ignore them.
func (l *ListReportLogic) Validate(in *oa.ListReportReq) (startDate, endDate time.Time, err error) {
	switch {
	case in.GetTargetUserEmail() == "":
		return startDate, endDate, status.Error(codes.InvalidArgument, "target_user_email参数必填")
	case in.GetUserEmail() == "":
		return startDate, endDate, ErrUserEmailRequired
	}

	if startDate, err = time.ParseInLocation(time.DateOnly, in.GetReportStartDate(), time.Local); err != nil {
		return startDate, endDate, status.Error(codes.InvalidArgument, "report_start_date参数格式错误")
	}
	if endDate, err = time.ParseInLocation(time.DateOnly, in.GetReportEndDate(), time.Local); err != nil {
		return startDate, endDate, status.Error(codes.InvalidArgument, "report_end_date参数格式错误")
	}
	if endDate.Before(startDate) {
		return startDate, endDate, status.Error(codes.InvalidArgument, "report_end_date在report_start_date之前")
	}

	// Normalize paging in place so downstream readers see sane values.
	if in.GetPageNum() <= 0 {
		in.PageNum = 1
	}
	if size := in.GetPageSize(); size <= 0 || size > ReportMaxPageSize {
		in.PageSize = ReportDefaultPageSize
	}

	return startDate, endDate, nil
}
