// compiler/compiler.go

package compiler

import (
//	"fmt"
	"SB/ast"
	"SB/object"
	"SB/CMP/code"
)

// TBytecode is the result handed out by TCompiler.Bytecode: the instruction
// stream of the compiler's current scope together with the constant pool
// those instructions refer to.
type TBytecode struct {
	// Instructions is the emitted instruction stream.
	Instructions code.Instructions
	// Constants is the pool indexed by the operand of OP_CONST.
	Constants []object.IObject
}

// TEmittedInstruction records an instruction that has been written to a
// scope's instruction stream.
type TEmittedInstruction struct {
	// Opcode is the opcode of the emitted instruction.
	Opcode code.Opcode
	// Position is the offset in the instruction stream at which the
	// instruction starts.
	Position int
}

// TCompilationScope holds the compilation state of a single scope: its own
// instruction stream plus bookkeeping for the two most recently emitted
// instructions.
type TCompilationScope struct {
	// instructions is the instruction stream emitted in this scope so far.
	instructions code.Instructions
	// lastInstruction is the most recently emitted instruction.
	lastInstruction TEmittedInstruction
	// previousInstruction is the instruction emitted before lastInstruction.
	previousInstruction TEmittedInstruction
}


// TCompiler turns AST nodes into bytecode. It keeps a symbol table, a
// constant pool and a stack of compilation scopes.
type TCompiler struct {
	// symbolTable is the symbol table of the scope currently being compiled.
	symbolTable *TSymbolTable
	// constants is the pool of constant objects shared by all scopes.
	constants []object.IObject
	// scopes is the stack of compilation scopes.
	scopes []TCompilationScope
	// scopeIndex is the index into scopes of the scope being compiled.
	scopeIndex int
}


// New returns a compiler with a single, empty main scope, an empty constant
// pool and a fresh symbol table in which every entry of object.Builtins is
// registered under its index.
func New() *TCompiler {
	table := NewSymbolTable()
	for index, builtin := range object.Builtins {
		table.DefineBuiltin(index, builtin.Name)
	}

	compiler := &TCompiler{
		symbolTable: table,
		constants:   []object.IObject{},
		scopeIndex:  0,
	}
	// The main scope starts with an empty (non-nil) instruction stream and
	// zero-valued last/previous instruction records.
	compiler.scopes = []TCompilationScope{
		{
			instructions:        code.Instructions{},
			lastInstruction:     TEmittedInstruction{},
			previousInstruction: TEmittedInstruction{},
		},
	}
	return compiler
}


// currentInstructions returns the instruction stream of the scope selected
// by scopeIndex.
func (c *TCompiler) currentInstructions() code.Instructions {
	scope := &c.scopes[c.scopeIndex]
	return scope.instructions
}


// Bytecode packages the current scope's instructions and the constant pool
// into a newly allocated TBytecode.
func (c *TCompiler) Bytecode() *TBytecode {
	result := new(TBytecode)
	result.Instructions = c.currentInstructions()
	result.Constants = c.constants
	return result
}


// addConstant appends obj to the constant pool and returns the index at
// which it was stored.
func (c *TCompiler) addConstant(obj object.IObject) int {
	index := len(c.constants)
	c.constants = append(c.constants, obj)
	return index
}


// emit encodes op with its operands, appends the encoded instruction to the
// current scope, records it as the last emitted instruction and returns the
// position at which it was written.
func (c *TCompiler) emit(op code.Opcode, operands ...int) int {
	position := c.addInstruction(code.Make(op, operands...))
	c.setLastInstruction(op, position)
	return position
}


// addInstruction appends the raw bytes of ins to the current scope's
// instruction stream and returns the offset at which they start.
func (c *TCompiler) addInstruction(ins []byte) int {
	scope := &c.scopes[c.scopeIndex]

	start := len(scope.instructions)
	scope.instructions = append(scope.instructions, ins...)
	return start
}


// setLastInstruction shifts the current scope's last instruction record into
// previousInstruction and stores (op, pos) as the new last instruction.
func (c *TCompiler) setLastInstruction(op code.Opcode, pos int) {
	scope := &c.scopes[c.scopeIndex]

	scope.previousInstruction = scope.lastInstruction
	scope.lastInstruction = TEmittedInstruction{Opcode: op, Position: pos}
}


// lastInstructionIs reports whether the current scope has emitted at least
// one byte of instructions and its last recorded instruction has opcode op.
func (c *TCompiler) lastInstructionIs(op code.Opcode) bool {
	scope := &c.scopes[c.scopeIndex]
	return len(scope.instructions) != 0 && scope.lastInstruction.Opcode == op
}


// removeLastPop drops the last emitted instruction from the current scope by
// truncating the instruction stream at that instruction's position, and
// makes the previous instruction the last one again. previousInstruction
// itself is left untouched.
func (c *TCompiler) removeLastPop() {
	scope := &c.scopes[c.scopeIndex]
	cut := scope.lastInstruction.Position
	restored := scope.previousInstruction

	scope.instructions = scope.instructions[:cut]
	scope.lastInstruction = restored
}


// replaceInstruction overwrites the bytes of the current scope's instruction
// stream, starting at pos, with newInstruction. The stream is modified in
// place; its length does not change.
func (c *TCompiler) replaceInstruction(pos int, newInstruction []byte) {
	target := c.currentInstructions()

	for offset, b := range newInstruction {
		target[pos+offset] = b
	}
}


// changeOperand re-encodes the instruction located at opPos with the same
// opcode but the given operand, and writes it back over the old encoding.
func (c *TCompiler) changeOperand(opPos int, operand int) {
	opcode := code.Opcode(c.currentInstructions()[opPos])
	c.replaceInstruction(opPos, code.Make(opcode, operand))
}


// NewWithState returns a compiler created by New whose symbol table and
// constant pool are then replaced by s and constants.
//
// TODO: for REPL test only
func NewWithState(s *TSymbolTable, constants []object.IObject) *TCompiler {
	c := New()
	c.symbolTable, c.constants = s, constants
	return c
}


// enterScope pushes a fresh, empty compilation scope, makes it the current
// one and switches to a new symbol table enclosed by the current table.
func (c *TCompiler) enterScope() {
	c.scopes = append(c.scopes, TCompilationScope{
		instructions:        code.Instructions{},
		lastInstruction:     TEmittedInstruction{},
		previousInstruction: TEmittedInstruction{},
	})
	c.scopeIndex++

	enclosed := NewEnclosedSymbolTable(c.symbolTable)
	c.symbolTable = enclosed
}

// leaveScope pops the current compilation scope, switches back to the outer
// symbol table and returns the instructions that were emitted in the scope
// being left.
func (c *TCompiler) leaveScope() code.Instructions {
	emitted := c.currentInstructions()

	remaining := len(c.scopes) - 1
	c.scopes = c.scopes[:remaining]
	c.scopeIndex--

	outer := c.symbolTable.Outer
	c.symbolTable = outer

	return emitted
}


// loadSymbol emits the "get" instruction matching the scope of s (global,
// local or builtin) with the symbol's index as operand. Symbols with any
// other scope emit nothing.
func (c *TCompiler) loadSymbol(s TSymbol) {
	var getter code.Opcode

	switch s.Scope {
	case GlobalScope:
		getter = code.OP_GETGL
	case LocalScope:
		getter = code.OP_GETLC
	case BuiltinScope:
		getter = code.OP_GETBU
	default:
		return
	}

	c.emit(getter, s.Index)
}


// getBuiltinType maps a type expression to the object that represents the
// corresponding builtin data type.
//
// node must be a call expression whose routine is an identifier that
// resolves to a symbol in the builtin scope. On success the type object and
// true are returned; otherwise an error object and false are returned:
//   - node is not a call expression: ERR_INVALID_DATA_TYPE("")
//   - the routine is not an identifier: ERR_INVALID_FUNCTION_DATA_TYPE()
//   - the identifier cannot be resolved: ERR_UNKNOWN_IDENTIFIER(name)
//   - the symbol is not a builtin, or its index is not a known type:
//     ERR_INVALID_DATA_TYPE(name)
func (c *TCompiler) getBuiltinType(node ast.INode) (object.IObject, bool) {
	call, ok := node.(*ast.TCallExpression)
	if !ok {
		return object.ERR_INVALID_DATA_TYPE(""), false
	}

	// should be identifier
	ident, ok := call.Routine.(*ast.TIdentifier)
	if !ok {
		return object.ERR_INVALID_FUNCTION_DATA_TYPE(), false
	}

	// should be in symbol table
	symbol, ok := c.symbolTable.Resolve(ident.Value)
	if !ok {
		// Report the name that was looked up: the symbol returned by a
		// failed Resolve does not describe the requested identifier.
		return object.ERR_UNKNOWN_IDENTIFIER(ident.Value), false
	}

	// should be builtin (for now)
	if symbol.Scope != BuiltinScope {
		return object.ERR_INVALID_DATA_TYPE(symbol.Name), false
	}

	// fetch type
	// NOTE(review): the indices below presumably mirror the order of
	// object.Builtins (New registers each builtin under its slice index) —
	// confirm they stay in sync when builtins are added or reordered.
	switch symbol.Index {
	case 0: // nil
		return object.NULL, true
	case 1: // boolean
		return object.FALSE, true
	case 2: // byte
		return object.BYTE, true
	case 3: // double
		return object.DOUBLE, true
	case 4: // single
		return object.SINGLE, true
	case 5: // integer
		return object.INTEGER, true
	case 6: // long
		return object.LONG, true
	case 7: // short
		return object.SHORT, true
	case 8: // string
		return object.STRING, true
	case 9: // ubyte
		return object.UBYTE, true
	case 10: // uinteger
		return object.UINTEGER, true
	case 11: // ulong
		return object.ULONG, true
	case 12: // ushort
		return object.USHORT, true
	case 13: // func
		return object.FUNCTION, true
	//case 14:	// array
	//case 15:	// list
	//case 16:	// map
	default:
		return object.ERR_INVALID_DATA_TYPE(symbol.Name), false
	}
}


// Compile translates node, and recursively its children, into bytecode that
// is appended to the current compilation scope. Constants are added to the
// compiler's constant pool and names are defined in / resolved against the
// current symbol table.
//
// It returns nil on success, or an error object describing the first problem
// encountered (unknown identifier, invalid operator, invalid return type).
// Node types that are not handled below emit nothing and yield nil.
func (c *TCompiler) Compile(node ast.INode) object.IObject {
	switch node := node.(type) {

	case *ast.TProgram:
		for _, s := range node.Statements {
			if err := c.Compile(s); err != nil {
				return err
			}
		}

	case *ast.TExpressionStatement:
		if err := c.Compile(node.Expression); err != nil {
			return err
		}

	case *ast.TVarStatement:
		// NOTE(review): value is declared outside the loop, so when there
		// are fewer values than names the last value expression is compiled
		// again for each remaining name; OP_NULL is only emitted while no
		// value has been seen at all — confirm this is intended.
		var value ast.IExpression

		for i := 0; i < len(node.Names); i++ {
			if i < len(node.Values) {
				value = node.Values[i]
			}
			if value == nil {
				c.emit(code.OP_NULL)
			} else {
				if err := c.Compile(value); err != nil {
					return err
				}
			}
			symbol := c.symbolTable.Define(node.Names[i].Value)
			if symbol.Scope == GlobalScope {
				c.emit(code.OP_DEFGL, symbol.Index)
			} else {
				c.emit(code.OP_DEFLC, symbol.Index)
			}
		}

	case *ast.TAssignExpression:
		if err := c.Compile(node.Value); err != nil {
			return err
		}
		// Only identifier targets carry a name; any other target leaves name
		// empty and is reported as an unknown identifier below.
		var name string
		if ident, ok := node.Name.(*ast.TIdentifier); ok {
			name = ident.Value
		}
		symbol, ok := c.symbolTable.Resolve(name)
		if !ok {
			return object.ERR_UNKNOWN_IDENTIFIER(name)
		}
		// NOTE(review): every non-global symbol, including one in the
		// builtin scope, is stored with OP_MOVLC — confirm that assigning to
		// a builtin name should not be rejected instead.
		if symbol.Scope == GlobalScope {
			c.emit(code.OP_MOVGL, symbol.Index)
		} else {
			c.emit(code.OP_MOVLC, symbol.Index)
		}

	case *ast.TIdentifier:
		symbol, ok := c.symbolTable.Resolve(node.Value)
		if !ok {
			// Report the name that was looked up: the symbol returned by a
			// failed Resolve does not describe the requested identifier.
			return object.ERR_UNKNOWN_IDENTIFIER(node.Value)
		}
		c.loadSymbol(symbol)

	case *ast.TInfixExpression:
		// "<" and "<=" have no opcode of their own: the operands are
		// compiled in reverse order and the "greater" opcodes are reused.
		switch node.Operator {
		case "<":
			if err := c.Compile(node.Right); err != nil {
				return err
			}
			if err := c.Compile(node.Left); err != nil {
				return err
			}
			c.emit(code.OP_GRE)
			return nil
		case "<=":
			if err := c.Compile(node.Right); err != nil {
				return err
			}
			if err := c.Compile(node.Left); err != nil {
				return err
			}
			c.emit(code.OP_GEQ)
			return nil
		}

		if err := c.Compile(node.Left); err != nil {
			return err
		}

		if err := c.Compile(node.Right); err != nil {
			return err
		}

		switch node.Operator {

		case "+":
			c.emit(code.OP_ADD)

		case "-":
			c.emit(code.OP_SUB)

		case "*":
			c.emit(code.OP_MUL)

		case "/":
			c.emit(code.OP_DIV)

		case "%":
			c.emit(code.OP_MOD)

		case "^":
			c.emit(code.OP_POW)

		case "==":
			c.emit(code.OP_EQU)

		case "<>":
			c.emit(code.OP_NEQ)

		case ">":
			c.emit(code.OP_GRE)

		case ">=":
			c.emit(code.OP_GEQ)

		case "and":
			c.emit(code.OP_AND)

		case "nor":
			c.emit(code.OP_NOR)

		case "or":
			c.emit(code.OP_OR)

		case "xor":
			c.emit(code.OP_XOR)

		default:
			return object.ERR_INVALID_OPERATOR(node.Operator)
		}

	case *ast.TIntegerLiteral:
		obj := &object.TInteger{Value: node.Value}
		c.emit(code.OP_CONST, c.addConstant(obj))

	case *ast.TFloatLiteral:
		obj := &object.TDouble{Value: node.Value}
		c.emit(code.OP_CONST, c.addConstant(obj))

	case *ast.TStringLiteral:
		obj := &object.TString{Value: node.Value}
		c.emit(code.OP_CONST, c.addConstant(obj))

	case *ast.TBoolean:
		if node.Value {
			c.emit(code.OP_TRUE)
		} else {
			c.emit(code.OP_FALSE)
		}

	case *ast.TPrefixExpression:
		if err := c.Compile(node.Right); err != nil {
			return err
		}

		switch node.Operator {
		case "is":
			c.emit(code.OP_IS)

		case "not":
			c.emit(code.OP_NOT)

		case "-":
			c.emit(code.OP_MIN)

		default:
			return object.ERR_INVALID_OPERATOR(node.Operator)
		}

	case *ast.TIfStatement:
		// One pending "jump past the whole if" per condition branch; their
		// targets are only known once every branch has been compiled.
		jmpPos := make([]int, len(node.Conditions))

		for i, n := range node.Conditions {
			if err := c.Compile(n.Condition); err != nil {
				return err
			}

			// jump to next condition, default or past the if block
			jntPos := c.emit(code.OP_JNT, code.UNKNOWN_POS)

			if err := c.Compile(n.Body); err != nil {
				return err
			}

			if c.lastInstructionIs(code.OP_POP) {
				c.removeLastPop()
			}

			// jump past the if block
			jmpPos[i] = c.emit(code.OP_JMP, code.UNKNOWN_POS)

			// position of next condition
			truePos := len(c.currentInstructions())
			c.changeOperand(jntPos, truePos)
		}

		// final else block (if any)
		if node.Default == nil {
			c.emit(code.OP_NULL)
		} else {
			if err := c.Compile(node.Default); err != nil {
				return err
			}

			if c.lastInstructionIs(code.OP_POP) {
				c.removeLastPop()
			}
		}

		// fill in the position just past the if block
		p := len(c.currentInstructions())
		for i := 0; i < len(jmpPos); i++ {
			c.changeOperand(jmpPos[i], p)
		}

	case *ast.TBlockStatement:
		for _, s := range node.Statements {
			if err := c.Compile(s); err != nil {
				return err
			}
		}

	case *ast.TFunctionStatement:
		if err := c.Compile(node.FunctionLiteral); err != nil {
			return err
		}
		symbol := c.symbolTable.Define(node.Name.String())
		if symbol.Scope == GlobalScope {
			c.emit(code.OP_DEFGL, symbol.Index)
		} else {
			c.emit(code.OP_DEFLC, symbol.Index)
		}

	case *ast.TFunctionLiteral:
		// On failure getBuiltinType returns the error object in place of
		// the type.
		rt, ok := c.getBuiltinType(node.ReturnType)
		if !ok {
			return rt
		}

		c.enterScope()

		// NOTE(review): the type assertion on the range key implies that
		// Parameters is a map, and Go does not specify map iteration order,
		// so the order in which parameters are defined here may vary
		// between runs — confirm against the AST definition.
		for key := range node.Parameters {
			name := key.(*ast.TIdentifier)
			c.symbolTable.Define(name.Value)
		}

		if err := c.Compile(node.Body); err != nil {
			// Pop the scope entered above so the compiler's scope stack and
			// symbol table stay balanced when compilation fails.
			c.leaveScope()
			return err
		}

		if !c.lastInstructionIs(code.OP_RETV) {
			c.emit(code.OP_RETV)
		}

		// Read the definition count before leaveScope switches back to the
		// outer symbol table.
		lc := c.symbolTable.numDefinitions
		in := c.leaveScope()

		cf := &object.TCompiledFunction{
			Instructions:  in,
			NumLocals:     lc,
			NumParameters: len(node.Parameters),
			ReturnType:    rt,
		}
		c.emit(code.OP_CONST, c.addConstant(cf))

	case *ast.TReturnStatement:
		if err := c.Compile(node.ReturnValue); err != nil {
			return err
		}
		c.emit(code.OP_RETV)

	case *ast.TSubLiteral:
		c.enterScope()

		// NOTE(review): same map-iteration-order concern as for function
		// literal parameters — confirm against the AST definition.
		for key := range node.Parameters {
			name := key.(*ast.TIdentifier)
			c.symbolTable.Define(name.Value)
		}

		if err := c.Compile(node.Body); err != nil {
			// Pop the scope entered above so the compiler's scope stack and
			// symbol table stay balanced when compilation fails.
			c.leaveScope()
			return err
		}

		c.emit(code.OP_RET)

		// Read the definition count before leaveScope switches back to the
		// outer symbol table.
		lc := c.symbolTable.numDefinitions
		in := c.leaveScope()

		cs := &object.TCompiledSub{
			Instructions:  in,
			NumLocals:     lc,
			NumParameters: len(node.Parameters),
		}
		c.emit(code.OP_CONST, c.addConstant(cs))

	case *ast.TCallExpression:
		// The callee is compiled first, then the arguments in order; the
		// call instruction carries the argument count.
		if err := c.Compile(node.Routine); err != nil {
			return err
		}
		for _, a := range node.Arguments {
			if err := c.Compile(a); err != nil {
				return err
			}
		}
		c.emit(code.OP_CALL, len(node.Arguments))

	}

	return nil
}
