package html
import (
	"bytes"
	"fmt"
	"io"
	"strings"
	"github.com/tdewolff/parse/v2"
	"github.com/tdewolff/parse/v2/css"
)
// AST is the document tree produced by Parse: the top-level tags plus the
// text buffer that every tag's text range indexes into.
type AST struct {
	// Children holds the top-level tags in the order they were opened.
	Children []*Tag
	// Text is the concatenation of all text tokens seen while parsing.
	Text     []byte
}
// String renders every top-level tag with ASTString and joins the results
// with a newline. An AST without children renders as the empty string.
func (ast *AST) String() string {
	trees := make([]string, 0, len(ast.Children))
	for _, child := range ast.Children {
		trees = append(trees, child.ASTString())
	}
	return strings.Join(trees, "\n")
}
// Attr is a single attribute of a tag, stored as raw bytes.
type Attr struct {
	Key, Val []byte
}

// String formats the attribute as key="val". Both parts are written verbatim;
// no escaping is applied to the value.
func (attr *Attr) String() string {
	key, val := attr.Key, attr.Val
	return fmt.Sprintf(`%s="%s"`, key, val)
}
// Tag is one element in the tree built by Parse.
type Tag struct {
	// Root is the AST this tag belongs to; Text reads from Root.Text.
	Root       *AST
	// Parent is the tag that was open when this tag started.
	Parent     *Tag
	// Prev and Next are the neighbouring siblings under the same parent, if any.
	Prev, Next *Tag
	// Children are the tags opened while this tag was the current one.
	Children   []*Tag
	// Index is this tag's position in its parent's Children at creation time.
	Index      int
	// Name is the tag name as reported by the lexer.
	Name               []byte
	// Attrs are the attributes in source order.
	Attrs              []Attr
	// textStart and textEnd are byte offsets into Root.Text delimiting the
	// text collected between this tag's start and its end.
	textStart, textEnd int
}
// getAttr returns the value of the first attribute whose key equals key
// byte-for-byte, and reports whether such an attribute was found.
func (tag *Tag) getAttr(key []byte) ([]byte, bool) {
	for i := range tag.Attrs {
		if candidate := &tag.Attrs[i]; bytes.Equal(key, candidate.Key) {
			return candidate.Val, true
		}
	}
	return nil, false
}
// GetAttr is the string-typed counterpart of getAttr: it returns the value of
// the first attribute named key, or "" and false when there is none.
func (tag *Tag) GetAttr(key string) (string, bool) {
	if raw, found := tag.getAttr([]byte(key)); found {
		return string(raw), true
	}
	return "", false
}
// Text returns the slice of the owning AST's text that lies between the
// tag's recorded start and end offsets.
//
// It returns "" instead of panicking when the tag has no Root or when the
// offsets do not describe a valid range, as can happen for a zero-value or
// hand-constructed Tag.
func (tag *Tag) Text() string {
	if tag.Root == nil {
		return ""
	}
	text := tag.Root.Text
	start, end := tag.textStart, tag.textEnd
	if len(text) < end {
		end = len(text)
	}
	if start < 0 || end < start {
		return ""
	}
	return string(text[start:end])
}
// String renders the opening tag only: "<name", each attribute preceded by a
// space, and a closing ">".
func (tag *Tag) String() string {
	var sb strings.Builder
	sb.WriteByte('<')
	sb.Write(tag.Name)
	for i := range tag.Attrs {
		sb.WriteByte(' ')
		sb.WriteString(tag.Attrs[i].String())
	}
	sb.WriteByte('>')
	return sb.String()
}
// ASTString renders the tag followed by its whole subtree, one tag per line,
// with every nesting level indented by two more spaces.
func (tag *Tag) ASTString() string {
	var sb strings.Builder
	sb.WriteString(tag.String())
	for _, child := range tag.Children {
		// Prefixing the newline lets a single replacement indent the child's
		// first line together with all of its descendants.
		subtree := "\n" + child.ASTString()
		sb.WriteString(strings.ReplaceAll(subtree, "\n", "\n  "))
	}
	return sb.String()
}
// Parse reads HTML tokens from r and builds an AST of the tags it encounters.
//
// Text tokens are appended to AST.Text, and every tag records the range of
// that text seen between its start and its end. Tags listed in voidTags are
// closed as soon as their start tag is complete. An end tag (or a
// self-closing "/>") closes the innermost open tag with the same name along
// with everything opened inside it; one that matches no open tag is ignored.
// Tags that are still open when the input ends are closed at that point.
//
// Top-level tags have Parent set to an internal placeholder tag, not nil.
//
// Any lexer error other than io.EOF is returned with a nil AST.
func Parse(r *parse.Input) (*AST, error) {
	ast := &AST{}
	root := &Tag{} // placeholder that collects the top-level tags
	cur := root
	l := NewLexer(r)
	for {
		tt, data := l.Next()
		switch tt {
		case ErrorToken:
			if err := l.Err(); err != io.EOF {
				return nil, err
			}
			// Give unclosed tags a valid text range; otherwise textEnd stays
			// zero and Text() would slice with end < start.
			for ; cur != root; cur = cur.Parent {
				cur.textEnd = len(ast.Text)
			}
			ast.Children = root.Children
			return ast, nil
		case TextToken:
			ast.Text = append(ast.Text, data...)
		case StartTagToken:
			child := &Tag{
				Root:      ast,
				Parent:    cur,
				Index:     len(cur.Children),
				Name:      l.Text(),
				textStart: len(ast.Text),
			}
			if 0 < len(cur.Children) {
				child.Prev = cur.Children[len(cur.Children)-1]
				child.Prev.Next = child
			}
			cur.Children = append(cur.Children, child)
			cur = child
		case AttributeToken:
			val := l.AttrVal()
			if 0 < len(val) && (val[0] == '"' || val[0] == '\'') {
				// Strip the opening quote, and the closing one only when it is
				// really there: an unterminated value may be a lone quote or
				// end in a character that belongs to the value.
				quote := val[0]
				val = val[1:]
				if n := len(val); 0 < n && val[n-1] == quote {
					val = val[:n-1]
				}
			}
			cur.Attrs = append(cur.Attrs, Attr{l.AttrKey(), val})
		case StartTagCloseToken:
			if voidTags[string(cur.Name)] {
				cur.textEnd = len(ast.Text)
				cur = cur.Parent
			}
		case EndTagToken, StartTagVoidToken:
			// Find the innermost open tag with the lexer's current tag name.
			name := l.Text()
			start := cur
			for start != root && !bytes.Equal(name, start.Name) {
				start = start.Parent
			}
			if start != root {
				// Close that tag and every tag still open inside it.
				for parent := start.Parent; cur != parent; cur = cur.Parent {
					cur.textEnd = len(ast.Text)
				}
			}
		}
	}
}
// Query parses s as a selector and returns the first tag, in document order,
// that it applies to. It returns (nil, nil) when nothing matches and a
// non-nil error only when the selector cannot be parsed.
func (ast *AST) Query(s string) (*Tag, error) {
	sel, err := ParseSelector(s)
	if err != nil {
		return nil, err
	}
	var found *Tag
	for i := 0; found == nil && i < len(ast.Children); i++ {
		found = ast.Children[i].query(sel)
	}
	return found, nil
}
// query returns the first tag in the subtree rooted at tag (the tag itself
// first, then its descendants in pre-order) that sel applies to, or nil.
func (tag *Tag) query(sel selector) *Tag {
	pending := []*Tag{tag}
	for 0 < len(pending) {
		node := pending[len(pending)-1]
		pending = pending[:len(pending)-1]
		if sel.AppliesTo(node) {
			return node
		}
		// Push children right-to-left so the leftmost child is visited next.
		for i := len(node.Children) - 1; 0 <= i; i-- {
			pending = append(pending, node.Children[i])
		}
	}
	return nil
}
// QueryAll parses s as a selector and returns every tag it applies to, in
// document order. The result is an empty, non-nil slice when nothing matches;
// an error is returned only when the selector cannot be parsed.
func (ast *AST) QueryAll(s string) ([]*Tag, error) {
	sel, err := ParseSelector(s)
	if err != nil {
		return nil, err
	}
	matches := make([]*Tag, 0)
	for i := range ast.Children {
		ast.Children[i].queryAll(&matches, sel)
	}
	return matches, nil
}
// queryAll appends to *matches every tag in the subtree rooted at tag that
// sel applies to, visiting the tag itself before its children.
func (tag *Tag) queryAll(matches *[]*Tag, sel selector) {
	if sel.AppliesTo(tag) {
		*matches = append(*matches, tag)
	}
	for i := range tag.Children {
		tag.Children[i].queryAll(matches, sel)
	}
}
// attrSelector is one attribute condition of a selector node.
type attrSelector struct {
	// op selects the comparison: 0 tests presence only, '=' exact equality,
	// '~' membership in a space-separated list, '|' equality or a "val-" prefix.
	op   byte // empty, =, ~, |
	// attr is the attribute name to look up on the tag.
	attr []byte
	// val is the value compared against; unused when op is 0.
	val  []byte
}
// AppliesTo reports whether tag carries the attribute sel.attr with a value
// that satisfies sel.op:
//
//	0    the attribute is present
//	'='  the value equals sel.val
//	'~'  the value, read as a whitespace-separated list, contains sel.val
//	'|'  the value equals sel.val or starts with sel.val followed by '-'
//
// Any other operator never matches. The selector and the tag are only read.
func (sel attrSelector) AppliesTo(tag *Tag) bool {
	val, ok := tag.getAttr(sel.attr)
	if !ok {
		return false
	}
	switch sel.op {
	case 0:
		return true
	case '=':
		return bytes.Equal(val, sel.val)
	case '~':
		if len(sel.val) == 0 {
			// An empty word can never be a member of the list.
			return false
		}
		// Split on the ASCII whitespace HTML uses to separate tokens, so that
		// tab- or newline-separated values match as well as space-separated.
		words := bytes.FieldsFunc(val, func(r rune) bool {
			return r == ' ' || r == '\t' || r == '\n' || r == '\f' || r == '\r'
		})
		for _, word := range words {
			if bytes.Equal(word, sel.val) {
				return true
			}
		}
	case '|':
		// Compare in place; appending '-' to sel.val could write into the
		// spare capacity of the buffer that the selector value is sliced from.
		if !bytes.HasPrefix(val, sel.val) {
			return false
		}
		return len(val) == len(sel.val) || val[len(sel.val)] == '-'
	}
	return false
}
// String renders the condition without surrounding brackets: the bare
// attribute name for a presence test, otherwise name, operator and the value
// in double quotes (attr="val", attr~="val", attr|="val").
func (sel attrSelector) String() string {
	if sel.op == 0 {
		return string(sel.attr)
	}
	var sb strings.Builder
	sb.Write(sel.attr)
	if sel.op != '=' {
		// Two-character operators are stored by their first byte only.
		sb.WriteByte(sel.op)
	}
	sb.WriteByte('=')
	sb.WriteByte('"')
	sb.Write(sel.val)
	sb.WriteByte('"')
	return sb.String()
}
// selectorNode is one compound selector (type plus attribute conditions)
// together with the combinator that links it to the next node.
type selectorNode struct {
	// typ is the tag name to match; empty or "*" matches any tag.
	typ   []byte // is * for universal
	// attrs are the attribute conditions that must all hold.
	attrs []attrSelector
	// op is the combinator following this node: ' ' (descendant) or '>'
	// (child); it is 0 on the last node of a selector.
	op    byte // space or >, last is NULL
}
// AppliesTo reports whether tag satisfies this node on its own: the type
// must be empty, "*", or equal to the tag name, and every attribute
// condition must apply. The combinator is not considered here.
func (sel selectorNode) AppliesTo(tag *Tag) bool {
	anyType := len(sel.typ) == 0 || string(sel.typ) == "*"
	if !anyType && !bytes.Equal(sel.typ, tag.Name) {
		return false
	}
	for i := range sel.attrs {
		if !sel.attrs[i].AppliesTo(tag) {
			return false
		}
	}
	return true
}
// String renders the node in selector syntax: the type, then each condition
// as #id, .class or [attr...], then the trailing combinator. A descendant
// combinator is written as a single space and any other combinator is padded
// with one space on each side; the last node (op 0) has no suffix.
func (sel selectorNode) String() string {
	var sb strings.Builder
	sb.Write(sel.typ)
	for _, attr := range sel.attrs {
		switch {
		case attr.op == '=' && string(attr.attr) == "id":
			sb.WriteByte('#')
			sb.Write(attr.val)
		case attr.op == '~' && string(attr.attr) == "class":
			sb.WriteByte('.')
			sb.Write(attr.val)
		default:
			sb.WriteByte('[')
			sb.WriteString(attr.String())
			sb.WriteByte(']')
		}
	}
	switch sel.op {
	case 0:
		// Last node: nothing follows.
	case ' ':
		// The combinator is itself a space; padding it would print three.
		sb.WriteByte(' ')
	default:
		sb.WriteByte(' ')
		sb.WriteByte(sel.op)
		sb.WriteByte(' ')
	}
	return sb.String()
}
// token is one CSS lexer token buffered by ParseSelector so that it can look
// ahead while building selector nodes.
type token struct {
	// tt is the token type reported by the CSS lexer.
	tt   css.TokenType
	// data is the raw bytes of the token.
	data []byte
}
// selector is a parsed selector: a chain of nodes ordered from the outermost
// ancestor to the tag being matched, each linked to the next by its op.
type selector []selectorNode
// ParseSelector parses a CSS-like selector into a chain of selector nodes.
//
// Supported syntax: type names and "*", ".class", "#id", "[attr]",
// "[attr=val]", "[attr~=val]" and "[attr|=val]" (val an identifier or a
// quoted string), joined by the descendant (whitespace) and child (">")
// combinators. Whitespace around ">" and at either end of the input is not
// significant. Parsing is lenient: unsupported tokens, malformed attribute
// selectors and leading or dangling combinators are skipped, not reported.
//
// The returned selector always has at least one node; for empty input it is
// a single node that matches every tag. An error is returned only when the
// CSS lexer fails with something other than io.EOF.
func ParseSelector(s string) (selector, error) {
	ts := []token{}
	l := css.NewLexer(parse.NewInputString(s))
	for {
		tt, data := l.Next()
		if tt == css.ErrorToken {
			if err := l.Err(); err != io.EOF {
				return selector{}, err
			}
			break
		}
		ts = append(ts, token{tt: tt, data: data})
	}

	// Whitespace at the edges separates nothing and must not turn into a
	// descendant combinator.
	for 0 < len(ts) && ts[0].tt == css.WhitespaceToken {
		ts = ts[1:]
	}
	for 0 < len(ts) && ts[len(ts)-1].tt == css.WhitespaceToken {
		ts = ts[:len(ts)-1]
	}

	isCombinator := func(t token) bool {
		return t.tt == css.WhitespaceToken || isSelectorDelim(t, '>')
	}

	sel := selector{}
	node := selectorNode{}
	for i := 0; i < len(ts); i++ {
		t := ts[i]
		switch {
		case isCombinator(t):
			// Fold a run such as " > " into one combinator; '>' wins over
			// the whitespace surrounding it.
			op := byte(' ')
			j := i
			for ; j < len(ts) && isCombinator(ts[j]); j++ {
				if ts[j].tt == css.DelimToken {
					op = '>'
				}
			}
			// A combinator needs a node on both sides.
			if 0 < i && j < len(ts) {
				node.op = op
				sel = append(sel, node)
				node = selectorNode{}
			}
			i = j - 1
		case t.tt == css.IdentToken || isSelectorDelim(t, '*'):
			node.typ = t.data
		case t.tt == css.HashToken && 1 < len(t.data):
			// The hash token carries its leading '#'.
			node.attrs = append(node.attrs, attrSelector{op: '=', attr: []byte("id"), val: t.data[1:]})
		case (isSelectorDelim(t, '.') || isSelectorDelim(t, '#')) && i+1 < len(ts) && ts[i+1].tt == css.IdentToken:
			if t.data[0] == '#' {
				node.attrs = append(node.attrs, attrSelector{op: '=', attr: []byte("id"), val: ts[i+1].data})
			} else {
				node.attrs = append(node.attrs, attrSelector{op: '~', attr: []byte("class"), val: ts[i+1].data})
			}
			i++
		case t.tt == css.LeftBracketToken || isSelectorDelim(t, '['):
			attr, n, ok := parseAttrSelector(ts[i+1:])
			if ok {
				node.attrs = append(node.attrs, attr)
			}
			// Skip the bracket contents even when malformed so that they are
			// not misread as type selectors or combinators.
			i += n
		}
	}
	sel = append(sel, node)
	return sel, nil
}

// isSelectorDelim reports whether t is the single-character delimiter c.
func isSelectorDelim(t token, c byte) bool {
	return t.tt == css.DelimToken && 0 < len(t.data) && t.data[0] == c
}

// parseAttrSelector parses the tokens that follow an opening '[' and returns
// the attribute condition, how many tokens it consumed (through the closing
// ']' when there is one, otherwise all of ts), and whether parsing succeeded.
func parseAttrSelector(ts []token) (attrSelector, int, bool) {
	end := 0
	for end < len(ts) && ts[end].tt != css.RightBracketToken && !isSelectorDelim(ts[end], ']') {
		end++
	}
	if end == len(ts) {
		// Unterminated bracket: consume the remainder.
		return attrSelector{}, end, false
	}
	consumed := end + 1

	// Whitespace inside the brackets carries no meaning.
	inner := make([]token, 0, end)
	for _, t := range ts[:end] {
		if t.tt != css.WhitespaceToken {
			inner = append(inner, t)
		}
	}
	if len(inner) == 0 || inner[0].tt != css.IdentToken {
		return attrSelector{}, consumed, false
	}

	attr := attrSelector{attr: inner[0].data}
	if len(inner) == 1 {
		return attr, consumed, true
	} else if len(inner) != 3 {
		return attrSelector{}, consumed, false
	}

	switch {
	case isSelectorDelim(inner[1], '='):
		attr.op = '='
	case inner[1].tt == css.IncludeMatchToken:
		attr.op = '~'
	case inner[1].tt == css.DashMatchToken:
		attr.op = '|'
	default:
		return attrSelector{}, consumed, false
	}

	switch inner[2].tt {
	case css.IdentToken:
		attr.val = inner[2].data
	case css.StringToken:
		// String tokens include their surrounding quotes.
		val := inner[2].data
		if 2 <= len(val) {
			val = val[1 : len(val)-1]
		}
		attr.val = val
	default:
		return attrSelector{}, consumed, false
	}
	return attr, consumed, true
}
// AppliesTo reports whether tag is matched by the whole selector: the last
// node must apply to tag itself and every earlier node must apply to an
// ancestor reachable through the combinator that follows it (' ' for any
// ancestor, '>' for the direct parent). An empty selector applies to
// every tag.
func (sels selector) AppliesTo(tag *Tag) bool {
	if len(sels) == 0 {
		return true
	}
	last := len(sels) - 1
	return sels[last].AppliesTo(tag) && sels.matchAncestors(last-1, tag.Parent)
}

// matchAncestors reports whether nodes sels[0..isel] can be matched against
// the ancestor chain that starts at tag. Each node consumes the ancestor it
// matches, and a descendant step retries with farther ancestors when the
// rest of the chain fails with the nearest one.
func (sels selector) matchAncestors(isel int, tag *Tag) bool {
	if isel < 0 {
		return true
	}
	switch sels[isel].op {
	case ' ':
		for ; tag != nil; tag = tag.Parent {
			if sels[isel].AppliesTo(tag) && sels.matchAncestors(isel-1, tag.Parent) {
				return true
			}
		}
	case '>':
		return tag != nil && sels[isel].AppliesTo(tag) && sels.matchAncestors(isel-1, tag.Parent)
	}
	// No ancestor satisfied the step, or the combinator is unknown.
	return false
}
// String renders the selector by concatenating its nodes, each of which
// writes its own trailing combinator. An empty selector renders as "".
func (sels selector) String() string {
	var sb strings.Builder
	for _, sel := range sels {
		sb.WriteString(sel.String())
	}
	// Return the full text: nodes write no leading separator, so there is
	// nothing to trim, and slicing would panic on an empty rendering.
	return sb.String()
}
// voidTags is the set of tag names that Parse closes as soon as their start
// tag is complete, without waiting for an end tag.
var voidTags = map[string]bool{
	"area": true, "base": true, "br": true, "col": true,
	"embed": true, "hr": true, "img": true, "input": true,
	"link": true, "meta": true, "source": true, "track": true,
	"wbr": true,
}