package encoder

import (
	"fmt"
	"regexp"
	"strconv"
	"strings"
	"testing"

	"github.com/spf13/pflag"
	"github.com/stretchr/testify/assert"
)

// Compile-time assertions that *MultiEncoder implements pflag.Value and
// fmt.Scanner. The blank identifiers mean nothing is kept at run time.
var (
	_ pflag.Value = (*MultiEncoder)(nil)
	_ fmt.Scanner = (*MultiEncoder)(nil)
)

// TestEncodeString checks that MultiEncoder.String renders a mask as a
// comma-separated list of encoding names, printing "None" for an empty
// mask and a hex literal for a bit that has no name.
func TestEncodeString(t *testing.T) {
	type stringCase struct {
		mask MultiEncoder
		want string
	}
	cases := []stringCase{
		{0, "None"},
		{EncodeZero, "None"},
		{EncodeDoubleQuote, "DoubleQuote"},
		{EncodeDot, "Dot"},
		{EncodeWin, "LtGt,DoubleQuote,Colon,Question,Asterisk,Pipe"},
		{EncodeHashPercent, "Hash,Percent"},
		{EncodeSlash | EncodeDollar | EncodeColon, "Slash,Dollar,Colon"},
		{EncodeSlash | (1 << 31), "Slash,0x80000000"},
	}
	for _, c := range cases {
		assert.Equal(t, c.want, c.mask.String())
	}
}

// TestEncodeSet checks that MultiEncoder.Set parses the textual form of a
// mask, and that the empty string, unknown names and malformed hex values
// are rejected with an error.
func TestEncodeSet(t *testing.T) {
	type setCase struct {
		in      string
		want    MultiEncoder
		wantErr bool
	}
	cases := []setCase{
		{"", 0, true},
		{"None", 0, false},
		{"None", EncodeZero, false},
		{"DoubleQuote", EncodeDoubleQuote, false},
		{"Dot", EncodeDot, false},
		{"LtGt,DoubleQuote,Colon,Question,Asterisk,Pipe", EncodeWin, false},
		{"Hash,Percent", EncodeHashPercent, false},
		{"Slash,Dollar,Colon", EncodeSlash | EncodeDollar | EncodeColon, false},
		{"Slash,0x80000000", EncodeSlash | (1 << 31), false},
		{"Blerp", 0, true},
		{"0xFGFFF", 0, true},
	}
	for _, c := range cases {
		var parsed MultiEncoder
		err := parsed.Set(c.in)
		failed := err != nil
		assert.Equal(t, c.wantErr, failed, err)
		assert.Equal(t, c.want, parsed, c.in)
	}
}

// testCase is one row of a table-driven encode/decode test.
type testCase struct {
	mask MultiEncoder // encoder mask whose Encode/Decode methods are exercised
	in   string       // string handed to the method under test
	out  string       // result the method is expected to return for in
}

// TestEncodeSingleMask runs every entry of testCasesSingle as a numbered
// subtest: Encode must produce the expected output and Decode must turn
// that output back into the original input.
func TestEncodeSingleMask(t *testing.T) {
	for idx, tc := range testCasesSingle {
		enc := tc.mask
		name := strconv.FormatInt(int64(idx), 10)
		t.Run(name, func(t *testing.T) {
			encoded := enc.Encode(tc.in)
			if encoded != tc.out {
				t.Errorf("Encode(%q) want %q got %q", tc.in, tc.out, encoded)
			}
			if decoded := enc.Decode(encoded); decoded != tc.in {
				t.Errorf("Decode(%q) want %q got %q", encoded, tc.in, decoded)
			}
		})
	}
}

// TestEncodeSingleMaskEdge runs every entry of testCasesSingleEdge as a
// numbered subtest: Encode must produce the expected output and Decode
// must turn that output back into the original input.
func TestEncodeSingleMaskEdge(t *testing.T) {
	for idx, tc := range testCasesSingleEdge {
		enc := tc.mask
		name := strconv.FormatInt(int64(idx), 10)
		t.Run(name, func(t *testing.T) {
			encoded := enc.Encode(tc.in)
			if encoded != tc.out {
				t.Errorf("Encode(%q) want %q got %q", tc.in, tc.out, encoded)
			}
			if decoded := enc.Decode(encoded); decoded != tc.in {
				t.Errorf("Decode(%q) want %q got %q", encoded, tc.in, decoded)
			}
		})
	}
}

// TestEncodeDoubleMaskEdge runs every entry of testCasesDoubleEdge as a
// numbered subtest: Encode must produce the expected output and Decode
// must turn that output back into the original input.
func TestEncodeDoubleMaskEdge(t *testing.T) {
	for idx, tc := range testCasesDoubleEdge {
		enc := tc.mask
		name := strconv.FormatInt(int64(idx), 10)
		t.Run(name, func(t *testing.T) {
			encoded := enc.Encode(tc.in)
			if encoded != tc.out {
				t.Errorf("Encode(%q) want %q got %q", tc.in, tc.out, encoded)
			}
			if decoded := enc.Decode(encoded); decoded != tc.in {
				t.Errorf("Decode(%q) want %q got %q", encoded, tc.in, decoded)
			}
		})
	}
}

func TestEncodeInvalidUnicode(t *testing.T) {
	for i, tc := range []testCase{
		{
			mask: EncodeInvalidUtf8,
			in:   "\xBF",
			out:  "‛BF",
		}, {
			mask: EncodeInvalidUtf8,
			in:   "\xBF\xFE",
			out:  "‛BF‛FE",
		}, {
			mask: EncodeInvalidUtf8,
			in:   "a\xBF\xFEb",
			out:  "a‛BF‛FEb",
		}, {
			mask: EncodeInvalidUtf8,
			in:   "a\xBFξ\xFEb",
			out:  "a‛BFξ‛FEb",
		}, {
			mask: EncodeInvalidUtf8 | EncodeBackSlash,
			in:   "a\xBF\\\xFEb",
			out:  "a‛BF\‛FEb",
		}, {
			mask: 0,
			in:   "\xBF",
			out:  "\xBF",
		}, {
			mask: 0,
			in:   "\xBF\xFE",
			out:  "\xBF\xFE",
		}, {
			mask: 0,
			in:   "a\xBF\xFEb",
			out:  "a\xBF\xFEb",
		}, {
			mask: 0,
			in:   "a\xBFξ\xFEb",
			out:  "a\xBFξ\xFEb",
		}, {
			mask: EncodeBackSlash,
			in:   "a\xBF\\\xFEb",
			out:  "a\xBF\\xFEb",
		},
	} {
		e := tc.mask
		t.Run(strconv.FormatInt(int64(i), 10), func(t *testing.T) {
			got := e.Encode(tc.in)
			if got != tc.out {
				t.Errorf("Encode(%q) want %q got %q", tc.in, tc.out, got)
			}
			got2 := e.Decode(got)
			if got2 != tc.in {
				t.Errorf("Decode(%q) want %q got %q", got, tc.in, got2)
			}
		})
	}
}

// TestEncodeDot checks the EncodeDot flag. The table expects only the
// exact names "." and ".." to be rewritten, each '.' becoming a FULLWIDTH
// FULL STOP (U+FF0E); longer names made of dots, and any name under a zero
// mask, are left alone. Every case is also round-tripped through Decode.
//
// U+FF0E is written as the escape \uFF0E so the expected output stays
// distinguishable from an ASCII '.'.
func TestEncodeDot(t *testing.T) {
	for i, tc := range []testCase{
		{
			mask: 0,
			in:   ".",
			out:  ".",
		}, {
			mask: EncodeDot,
			in:   ".",
			out:  "\uFF0E",
		}, {
			mask: 0,
			in:   "..",
			out:  "..",
		}, {
			mask: EncodeDot,
			in:   "..",
			out:  "\uFF0E\uFF0E",
		}, {
			mask: EncodeDot,
			in:   "...",
			out:  "...",
		}, {
			mask: EncodeDot,
			in:   ". .",
			out:  ". .",
		},
	} {
		e := tc.mask
		t.Run(strconv.FormatInt(int64(i), 10), func(t *testing.T) {
			got := e.Encode(tc.in)
			if got != tc.out {
				t.Errorf("Encode(%q) want %q got %q", tc.in, tc.out, got)
			}
			got2 := e.Decode(got)
			if got2 != tc.in {
				t.Errorf("Decode(%q) want %q got %q", got, tc.in, got2)
			}
		})
	}
}

func TestDecodeHalf(t *testing.T) {
	for i, tc := range []testCase{
		{
			mask: 0,
			in:   "‛",
			out:  "‛",
		}, {
			mask: 0,
			in:   "‛‛",
			out:  "‛",
		}, {
			mask: 0,
			in:   "‛a‛",
			out:  "‛a‛",
		}, {
			mask: EncodeInvalidUtf8,
			in:   "a‛B‛Eg",
			out:  "a‛B‛Eg",
		}, {
			mask: EncodeInvalidUtf8,
			in:   "a‛B\‛Eg",
			out:  "a‛B\‛Eg",
		}, {
			mask: EncodeInvalidUtf8 | EncodeBackSlash,
			in:   "a‛B\‛Eg",
			out:  "a‛B\\‛Eg",
		},
	} {
		e := tc.mask
		t.Run(strconv.FormatInt(int64(i), 10), func(t *testing.T) {
			got := e.Decode(tc.in)
			if got != tc.out {
				t.Errorf("Decode(%q) want %q got %q", tc.in, tc.out, got)
			}
		})
	}
}

// oneDrive is the combined encoding mask whose Encode and Decode methods
// are measured by the BenchmarkOneDrive*New benchmarks.
const oneDrive = Standard |
	EncodeWin |
	EncodeBackSlash |
	EncodeHashPercent |
	EncodeDel |
	EncodeCtl |
	EncodeLeftTilde |
	EncodeRightSpace |
	EncodeRightPeriod

// benchTests is the shared fixture for the OneDrive benchmarks. Each entry
// holds a raw path (in) and the encoded form expected from the old
// implementation, replaceReservedChars (outOld), and from oneDrive.Encode
// (outNew). The restore benchmarks feed the encoded form back in and
// expect in.
//
// Encoded characters are written as \u escapes so they stay distinct from
// their ASCII look-alikes:
//
//	\uFF3C ＼  \uFF0A ＊  \uFF1C ＜  \uFF1E ＞  \uFF1F ？
//	\uFF1A ：  \uFF5C ｜  \uFF03 ＃  \uFF05 ％  \uFF02 ＂
//	\uFF0E ．  \uFF5E ～
//
// A '.' is only encoded at the end of a path segment and a '~' only at the
// start of one, which is why both stay ASCII in the reserved-character
// entries. Leading spaces are not encoded.
var benchTests = []struct {
	in     string
	outOld string
	outNew string
}{
	{
		"",
		"",
		"",
	},
	{
		"abc 123",
		"abc 123",
		"abc 123",
	},
	{
		`\*<>?:|#%".~`,
		"\uFF3C\uFF0A\uFF1C\uFF1E\uFF1F\uFF1A\uFF5C\uFF03\uFF05\uFF02.~",
		"\uFF3C\uFF0A\uFF1C\uFF1E\uFF1F\uFF1A\uFF5C\uFF03\uFF05\uFF02.~",
	},
	{
		`\*<>?:|#%".~/\*<>?:|#%".~`,
		"\uFF3C\uFF0A\uFF1C\uFF1E\uFF1F\uFF1A\uFF5C\uFF03\uFF05\uFF02.~/\uFF3C\uFF0A\uFF1C\uFF1E\uFF1F\uFF1A\uFF5C\uFF03\uFF05\uFF02.~",
		"\uFF3C\uFF0A\uFF1C\uFF1E\uFF1F\uFF1A\uFF5C\uFF03\uFF05\uFF02.~/\uFF3C\uFF0A\uFF1C\uFF1E\uFF1F\uFF1A\uFF5C\uFF03\uFF05\uFF02.~",
	},
	{
		" leading space",
		" leading space",
		" leading space",
	},
	{
		"~leading tilde",
		"\uFF5Eleading tilde",
		"\uFF5Eleading tilde",
	},
	{
		"trailing dot.",
		"trailing dot\uFF0E",
		"trailing dot\uFF0E",
	},
	{
		" leading space/ leading space/ leading space",
		" leading space/ leading space/ leading space",
		" leading space/ leading space/ leading space",
	},
	{
		"~leading tilde/~leading tilde/~leading tilde",
		"\uFF5Eleading tilde/\uFF5Eleading tilde/\uFF5Eleading tilde",
		"\uFF5Eleading tilde/\uFF5Eleading tilde/\uFF5Eleading tilde",
	},
	{
		"leading tilde/~leading tilde",
		"leading tilde/\uFF5Eleading tilde",
		"leading tilde/\uFF5Eleading tilde",
	},
	{
		"trailing dot./trailing dot./trailing dot.",
		"trailing dot\uFF0E/trailing dot\uFF0E/trailing dot\uFF0E",
		"trailing dot\uFF0E/trailing dot\uFF0E/trailing dot\uFF0E",
	},
}

// benchReplace is the shared body of the encode benchmarks. It applies f
// to every benchTests input b.N times and reports a benchmark error when a
// result differs from the expected column: outOld when old is true,
// outNew otherwise.
func benchReplace(b *testing.B, f func(string) string, old bool) {
	for i := 0; i < b.N; i++ {
		for _, tc := range benchTests {
			want := tc.outNew
			if old {
				want = tc.outOld
			}
			if got := f(tc.in); got != want {
				b.Errorf("Encode(%q) want %q got %q", tc.in, want, got)
			}
		}
	}
}

// benchRestore is the shared body of the decode benchmarks. It applies f
// to the encoded form of every benchTests entry (outOld when old is true,
// outNew otherwise) b.N times and reports a benchmark error when the
// result is not the entry's original input.
func benchRestore(b *testing.B, f func(string) string, old bool) {
	for i := 0; i < b.N; i++ {
		for _, tc := range benchTests {
			encoded := tc.outNew
			if old {
				encoded = tc.outOld
			}
			if got := f(encoded); got != tc.in {
				b.Errorf("Decode(%q) want %q got %q", encoded, tc.in, got)
			}
		}
	}
}

// BenchmarkOneDriveReplaceNew measures oneDrive.Encode over benchTests,
// checking results against the outNew column.
func BenchmarkOneDriveReplaceNew(b *testing.B) {
	const useOld = false
	benchReplace(b, oneDrive.Encode, useOld)
}

// BenchmarkOneDriveReplaceOld measures replaceReservedChars over
// benchTests, checking results against the outOld column.
func BenchmarkOneDriveReplaceOld(b *testing.B) {
	const useOld = true
	benchReplace(b, replaceReservedChars, useOld)
}

// BenchmarkOneDriveRestoreNew measures oneDrive.Decode over the outNew
// column of benchTests, checking that each original input is recovered.
func BenchmarkOneDriveRestoreNew(b *testing.B) {
	const useOld = false
	benchRestore(b, oneDrive.Decode, useOld)
}

// BenchmarkOneDriveRestoreOld measures restoreReservedChars over the
// outOld column of benchTests, checking that each original input is
// recovered.
func BenchmarkOneDriveRestoreOld(b *testing.B) {
	const useOld = true
	benchRestore(b, restoreReservedChars, useOld)
}

// State for the old OneDrive name encoding, kept as a baseline for the
// benchmarks.
var (
	// charMap maps each reserved ASCII character to its replacement. The
	// replacements are written as \u escapes rather than literal glyphs so
	// they cannot be confused with, or normalized into, the ASCII
	// characters they stand in for.
	charMap = map[rune]rune{
		'\\': '\uFF3C', // FULLWIDTH REVERSE SOLIDUS
		'*':  '\uFF0A', // FULLWIDTH ASTERISK
		'<':  '\uFF1C', // FULLWIDTH LESS-THAN SIGN
		'>':  '\uFF1E', // FULLWIDTH GREATER-THAN SIGN
		'?':  '\uFF1F', // FULLWIDTH QUESTION MARK
		':':  '\uFF1A', // FULLWIDTH COLON
		'|':  '\uFF5C', // FULLWIDTH VERTICAL LINE
		'#':  '\uFF03', // FULLWIDTH NUMBER SIGN
		'%':  '\uFF05', // FULLWIDTH PERCENT SIGN
		'"':  '\uFF02', // FULLWIDTH QUOTATION MARK - not on the list but seems to be reserved
		'.':  '\uFF0E', // FULLWIDTH FULL STOP
		'~':  '\uFF5E', // FULLWIDTH TILDE
		' ':  '\u2420', // SYMBOL FOR SPACE
	}
	// invCharMap is the inverse of charMap; it is populated in init.
	invCharMap map[rune]rune
	// fixEndingInPeriod matches a '.' at the end of a path segment.
	fixEndingInPeriod = regexp.MustCompile(`\.(/|$)`)
	// fixEndingWithSpace matches a ' ' at the end of a path segment.
	fixEndingWithSpace = regexp.MustCompile(` (/|$)`)
	// fixStartingWithTilde matches a '~' at the start of a path segment.
	fixStartingWithTilde = regexp.MustCompile(`(/|^)~`)
)

// init builds invCharMap by swapping the keys and values of charMap.
func init() {
	inverse := make(map[rune]rune, len(charMap))
	for plain, encoded := range charMap {
		inverse[encoded] = plain
	}
	invCharMap = inverse
}

// replaceReservedChars takes a path and substitutes any reserved
// characters in it
func replaceReservedChars(in string) string {
	// Folder names can't end with a period '.'
	in = fixEndingInPeriod.ReplaceAllString(in, string(charMap['.'])+"$1")
	// OneDrive for Business file or folder names cannot begin with a tilde '~'
	in = fixStartingWithTilde.ReplaceAllString(in, "$1"+string(charMap['~']))
	// Apparently file names can't start with space either
	in = fixEndingWithSpace.ReplaceAllString(in, string(charMap[' '])+"$1")
	// Encode reserved characters
	return strings.Map(func(c rune) rune {
		if replacement, ok := charMap[c]; ok && c != '.' && c != '~' && c != ' ' {
			return replacement
		}
		return c
	}, in)
}

// restoreReservedChars reverses replaceReservedChars by mapping every rune
// found in invCharMap back to its original character; all other runes are
// returned unchanged.
func restoreReservedChars(in string) string {
	restore := func(c rune) rune {
		original, found := invCharMap[c]
		if !found {
			return c
		}
		return original
	}
	return strings.Map(restore, in)
}