package period

import (
	"testing"
)

// TestGetPeriodList is a table-driven test for GetPeriodList.
//
// It covers these period types, going by the subtest names:
//   - 1: weekly
//   - 2: monthly
//   - 3: bimonthly
//   - 4: quarterly
//
// Each case passes a starting period as a "YYYYMMDD-YYYYMMDD" range, a count,
// and a direction. It expects exactly `count` ranges back, with the starting
// period first. With PeriodDirectionForward, the expected data here lists
// successively earlier periods.
func TestGetPeriodList(t *testing.T) {
	tests := []struct {
		name       string
		periodType int
		dateString string
		count      int
		direction  PeriodDirection
		want       []string
		wantErr    bool
	}{
		{
			name:       "周-向前",
			periodType: 1,
			dateString: "20250101-20250107",
			count:      3,
			direction:  PeriodDirectionForward,
			want:       []string{"20250101-20250107", "20241225-20241231", "20241218-20241224"},
			wantErr:    false,
		},
		{
			name:       "月-向前",
			periodType: 2,
			dateString: "20250101-20250131",
			count:      3,
			direction:  PeriodDirectionForward,
			want:       []string{"20250101-20250131", "20241201-20241231", "20241101-20241130"},
			wantErr:    false,
		},
		{
			name:       "双月-向前",
			periodType: 3,
			dateString: "20250301-20250430", // March–April bimonthly period
			count:      3,
			direction:  PeriodDirectionForward,
			want:       []string{"20250301-20250430", "20250101-20250228", "20241101-20241231"},
			wantErr:    false,
		},
		{
			name:       "季度-向前",
			periodType: 4,
			dateString: "20250101-20250331", // first quarter
			count:      3,
			direction:  PeriodDirectionForward,
			want:       []string{"20250101-20250331", "20241001-20241231", "20240701-20240930"},
			wantErr:    false,
		},
		{
			name:       "双月-向前-第二个案例",
			periodType: 3,
			dateString: "20250301-20250430", // March–April bimonthly period, longer run crossing into 2024
			count:      5,
			direction:  PeriodDirectionForward,
			want:       []string{"20250301-20250430", "20250101-20250228", "20241101-20241231", "20240901-20241031", "20240701-20240831"},
			wantErr:    false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := GetPeriodList(tt.periodType, tt.dateString, tt.count, tt.direction)
			if (err != nil) != tt.wantErr {
				t.Fatalf("GetPeriodList() error = %v, wantErr %v", err, tt.wantErr)
			}
			// When an error is expected (and was returned), the result carries no
			// meaning, so comparing it to tt.want would only produce noise.
			if tt.wantErr {
				return
			}

			// Check the length first so the element-wise loop below can index
			// tt.want safely.
			if len(got) != len(tt.want) {
				t.Fatalf("GetPeriodList() 长度不匹配 got = %v, want %v", got, tt.want)
			}

			// Report every mismatching position rather than stopping at the first.
			for i, period := range got {
				if period != tt.want[i] {
					t.Errorf("GetPeriodList() 在位置%d: got = %v, want %v", i, period, tt.want[i])
				}
			}
		})
	}
}
