package reportrpclogic

import (
	"context"
	"errors"
	"time"

	"oa-server/app/oacenter/model/report"
	"oa-server/app/oacenter/oa_rpc/internal/svc"
	"oa-server/app/oacenter/oa_rpc/oa"

	"github.com/zeromicro/go-zero/core/logx"
	"github.com/zeromicro/go-zero/core/stores/sqlx"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

type CreateReportLogic struct {
	ctx    context.Context
	svcCtx *svc.ServiceContext
	logx.Logger
}

// NewCreateReportLogic builds a CreateReportLogic for one request. The
// embedded logger is derived from ctx via logx.WithContext.
func NewCreateReportLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CreateReportLogic {
	logic := new(CreateReportLogic)
	logic.ctx = ctx
	logic.svcCtx = svcCtx
	logic.Logger = logx.WithContext(ctx)
	return logic
}

// CreateReport creates a work report together with its items and mentions.
//
// The request is first checked with Validate, then report_date is parsed as a
// date-only value in the server's local time zone. If a report already exists
// for the same (user_email, report_date, report_type) the call fails with
// ErrReportExists. Otherwise the x_report row, one x_report_item row per
// submitted item, and one x_report_mention row per mentioned user are
// inserted through ReportModel.TransCtx.
//
// On success the response carries the ID of the new report. Possible errors:
// whatever Validate returns, an InvalidArgument status when report_date
// cannot be parsed, ErrReportQueryFailed, ErrReportExists, and
// ErrReportCreateFailed.
//
// NOTE(review): the existence check and the insert are not atomic. If the
// table has a unique index on these columns (the finder name suggests so —
// confirm), a concurrent duplicate surfaces as ErrReportCreateFailed rather
// than ErrReportExists.
func (l *CreateReportLogic) CreateReport(in *oa.CreateReportReq) (*oa.CreateReportResp, error) {
	if err := l.Validate(in); err != nil {
		return nil, err
	}
	reportDate, err := time.ParseInLocation(time.DateOnly, in.GetReportDate(), time.Local)
	if err != nil {
		return nil, status.Error(codes.InvalidArgument, "report_date参数格式错误")
	}

	// Reject the request when this user already has a report of this type for
	// this date. ErrNotFound is the expected outcome and is not an error here.
	rpt, err := l.svcCtx.ReportModel.FindOneByUserEmailReportDateReportType(l.ctx, in.GetUserEmail(), reportDate, uint64(in.GetReportType()))
	if err != nil && !errors.Is(err, report.ErrNotFound) {
		l.Errorw("failed to query x_report on CreateReport.FindOneByUserEmailReportDateReportType", logx.Field("error", err))
		return nil, ErrReportQueryFailed
	}
	if rpt != nil {
		// Errorw (not Error) so the fields are emitted as structured log
		// fields instead of being flattened into the message text.
		l.Errorw("user's report exists", logx.Field("user_email", in.GetUserEmail()), logx.Field("report_date", reportDate), logx.Field("report_type", in.GetReportType()))
		return nil, ErrReportExists
	}

	rpt = &report.XReport{
		UserEmail:   in.GetUserEmail(),
		ReportType:  uint64(in.GetReportType()),
		ReportDate:  reportDate,
		IsDelayed:   convertBoolToInt64(in.GetIsDelayed()),
		DelayReason: in.GetDelayReason(),
		HasRisk:     convertBoolToInt64(in.GetHasRisk()),
		RiskDesc:    in.GetRiskDesc(),
		Status:      uint64(convertReportStatus(in.GetIsDraft())),
	}

	// Create the report, its items and their mentions in one transaction.
	err = l.svcCtx.ReportModel.TransCtx(l.ctx, func(ctx context.Context, s sqlx.Session) error {
		// Insert x_report and remember its generated ID for the child rows.
		r, err := l.svcCtx.ReportModel.TransInsertCtx(ctx, s, rpt)
		if err != nil {
			l.Errorw("failed to create x_report on CreateReport.ReportModel.TransInsertCtx", logx.Field("error", err))
			return err
		}
		id, err := r.LastInsertId()
		if err != nil {
			l.Errorw("failed to get x_report LastInsertId", logx.Field("error", err))
			return err
		}
		rpt.Id = uint64(id)

		// Insert every x_report_item and collect the mentions attached to it;
		// a mention needs the generated item ID, so it is built after the
		// item insert.
		var mentions []*report.XReportMention
		for _, v := range in.GetReportItemList() {
			result, err := l.svcCtx.ReportItemModel.TransInsertCtx(ctx, s, &report.XReportItem{
				ReportId: rpt.Id,
				Content:  v.Content,
			})
			if err != nil {
				l.Errorw("failed to create x_report_item on CreateReport.ReportItemModel.TransInsertCtx", logx.Field("error", err))
				return err
			}
			itemId, err := result.LastInsertId()
			if err != nil {
				l.Errorw("failed to get x_report_item LastInsertId on Create.ReportItemModel.TransInsertCtx", logx.Field("error", err))
				return err
			}

			for _, m := range v.MentionList {
				mentions = append(mentions, &report.XReportMention{
					ReportId:         rpt.Id,
					ReportType:       rpt.ReportType,
					ReportDate:       rpt.ReportDate,
					ReportItemId:     uint64(itemId),
					InitiatorEmail:   in.GetUserEmail(),
					MentionUserEmail: m,
				})
			}
		}

		// Insert x_report_mention rows. Their generated IDs are not needed,
		// so only the insert error is checked.
		for _, m := range mentions {
			if _, err := l.svcCtx.ReportMentionModel.TransInsertCtx(ctx, s, m); err != nil {
				l.Errorw("failed to create x_report_mention on CreateReport.ReportMentionModel.TransInsertCtx", logx.Field("error", err))
				return err
			}
		}

		return nil
	})
	if err != nil {
		// The underlying cause was already logged inside the transaction.
		return nil, ErrReportCreateFailed
	}

	return &oa.CreateReportResp{ReportId: rpt.Id}, nil
}

// Validate checks the required fields of a create-report request. It returns
// ErrUserEmailRequired when user_email is empty, an InvalidArgument status
// when report_date is empty, and nil otherwise. The user_email check runs
// first.
func (l *CreateReportLogic) Validate(in *oa.CreateReportReq) error {
	switch {
	case in.GetUserEmail() == "":
		return ErrUserEmailRequired
	case in.GetReportDate() == "":
		return status.Error(codes.InvalidArgument, "report_date参数必填")
	default:
		return nil
	}
}
