// This is effectively a library for .NET remoting functionality.
// The exploit remoting service tool (ExploitRemotingService) by tyranid was the primary reference.
// Note: Everything is in little endian.

// Usage Example:
//     data := "\x00\x00blahblah"
//     uri := "tcp://192.168.113.231:9999/SomeEndpoint"
//     // conn = get a net.Conn somehow...
//     newmessage := dotnetremoting.Message{}
//     newmessage.WriteDefaultPreamble(uri, dotnetremoting.OperationTypeRequest, len(data))
//     _, err := conn.Write([]byte(newmessage.GetMessage(data))) // NOTE: the GetMessage call here finalizes the message
//     if err != nil {
//         fmt.Println(fmt.Sprintf("Error sending: %s", err))
//         return
//     }
//     fmt.Println("sent")
//
//     // recv from end
//     buf := make ([]byte, 4096)
//     _,err = conn.Read(buf)
//     fmt.Println(fmt.Sprintf("%x", buf))
//     fmt.Println(fmt.Sprintf("%s", buf))

package dotnetremoting

import (
	"encoding/binary"
	"io"
	"net"
	"net/url"

	"github.com/vulncheck-oss/go-exploit/output"
	"github.com/vulncheck-oss/go-exploit/transform"
)

// types and 'enums'.

// Message accumulates the pieces of an outgoing .NET remoting message as raw
// byte strings. PreambleData is the fixed start of the message (magic,
// version, operation type, content distribution and payload length) and is
// set by WritePreamble. HeaderData is the encoded header list that the
// Add*Header methods append to. GetMessage joins both with the end-of-headers
// token and the body.
type Message struct {
	PreambleData string
	HeaderData   string
}

// MessageResponse is a parsed .NET remoting message as filled in by
// ParseResponseFromConn. The fixed-width protocol fields are kept as strings
// holding the raw bytes read from the connection rather than decoded numbers:
//
//   - MajorVersion and MinorVersion: one raw byte each.
//   - OperationType and ContentDistribution: two raw bytes each.
//   - DataLength: the four byte little endian length field, decoded to an int.
//   - Headers: the parsed headers; custom headers are keyed by their name,
//     every other header by its raw two byte token.
//   - Data: the DataLength bytes that follow the headers.
type MessageResponse struct {
	MajorVersion        string
	MinorVersion        string
	OperationType       string
	ContentDistribution string
	DataLength          int
	Headers             map[string]string
	Data                string
}

// OperationType is the operation type field of the message preamble, stored
// as the two raw bytes (little endian ushort) that go on the wire.
type OperationType string

// Operation type values. WritePreamble appends the chosen value verbatim.
const ( // OPERATION TYPES (ushort)
	OperationTypeRequest       OperationType = "\x00\x00"
	OperationTypeOneWayRequest OperationType = "\x01\x00"
	OperationTypeReply         OperationType = "\x02\x00"
)

// HeaderToken is the two byte (little endian ushort) token that starts every
// header entry, stored as its raw wire bytes.
type HeaderToken string

// Header token values. HeaderTokenEndHeaders terminates the header list:
// GetMessage appends it after the headers and readHeadersFromConn stops
// parsing when it reads it. HeaderTokenCustom marks a name/value header; the
// remaining tokens identify well-known headers.
const ( // HEADER TOKENS (ushort)
	HeaderTokenEndHeaders      HeaderToken = "\x00\x00"
	HeaderTokenCustom          HeaderToken = "\x01\x00"
	HeaderTokenStatusCode      HeaderToken = "\x02\x00"
	HeaderTokenStatusPhrase    HeaderToken = "\x03\x00"
	HeaderTokenRequestURI      HeaderToken = "\x04\x00"
	HeaderTokenCloseConnection HeaderToken = "\x05\x00"
	HeaderTokenContentType     HeaderToken = "\x06\x00"
)

// HeaderDataFormat is the one byte tag that follows a non-custom header token
// and tells the reader how the header's value is encoded.
type HeaderDataFormat string

// Header data format values. readHeadersFromConn reads no value for Void, a
// counted string for CountedString, and 1, 2 and 4 bytes for Byte, Uint16 and
// Int32 respectively.
const ( // HEADER DATA FORMAT (byte)
	HeaderDataFormatVoid          HeaderDataFormat = "\x00"
	HeaderDataFormatCountedString HeaderDataFormat = "\x01"
	HeaderDataFormatByte          HeaderDataFormat = "\x02"
	HeaderDataFormatUint16        HeaderDataFormat = "\x03"
	HeaderDataFormatInt32         HeaderDataFormat = "\x04"
)

// ContentDistribution is the content distribution field of the message
// preamble, stored as the two raw bytes (little endian ushort) that go on the
// wire.
type ContentDistribution string

// Content distribution values. WriteDefaultPreamble always uses
// ContentDistributionNotChunked.
const ( // CONTENT DISTRIBUTION (ushort)
	ContentDistributionNotChunked ContentDistribution = "\x00\x00"
	ContentDistributionChunked    ContentDistribution = "\x01\x00"
)

// StringEncoding is the one byte tag that addCountedString writes in front of
// a counted string to declare how the string bytes are encoded.
type StringEncoding string

// String encoding values. The header helpers in this file always write
// StringEncodingUtf8; readHeaderStringFromConn reads the tag but ignores it.
const ( // STRING ENCODING (byte)
	StringEncodingUnicode StringEncoding = "\x00"
	StringEncodingUtf8    StringEncoding = "\x01"
)

// TCPStatusCode is a status code value stored as a single raw byte.
// NOTE(review): nothing else in the visible code references this type or its
// values; AddStatusCodeHeader packs its own 0/1 value instead.
type TCPStatusCode string

// TCP status code values.
const ( // TCP STATUS CODE (byte)
	TCPStatusCodeSuccess TCPStatusCode = "\x00"
	TCPStatusCodeError   TCPStatusCode = "\x01"
)

// WritePreamble sets the fixed start of the message and appends the standard
// headers that accompany it.
//
// PreambleData is overwritten with the ".NET" magic, protocol version 1.0,
// opType, contentDistribution and dataLength packed as a little endian 32-bit
// value. A content type header for contentType is then appended to
// HeaderData. When uri is non-empty a request URI header carrying the full
// uri and a "__RequestUri" custom header carrying only its path are appended
// as well.
//
// If uri cannot be parsed a framework error is logged and msg is left
// untouched.
func (msg *Message) WritePreamble(uri string, opType OperationType, dataLength int, contentDistribution ContentDistribution, contentType string) {
	parsedURI, parseErr := url.Parse(uri)
	if parseErr != nil {
		output.PrintfFrameworkError("Could not write preamble: error trying to parse provided uri=%s, err=%s", uri, parseErr)

		return
	}

	const protocolVersion = "\x01\x00" // major version 1, minor version 0
	msg.PreambleData = ".NET" + protocolVersion + string(opType) + string(contentDistribution) + transform.PackLittleInt32(dataLength)

	msg.AddContentTypeHeader(contentType)

	// Without a target URI there are no request headers to add.
	if uri == "" {
		return
	}

	msg.AddURIHeader(uri, HeaderTokenRequestURI)
	msg.AddCustomHeader("__RequestUri", parsedURI.Path)
}

// WriteDefaultPreamble is WritePreamble with the usual defaults filled in:
// non-chunked content distribution and an "application/octet-stream" content
// type. Use WritePreamble directly when either needs to differ.
func (msg *Message) WriteDefaultPreamble(uri string, opType OperationType, dataLength int) {
	const defaultContentType = "application/octet-stream"

	msg.WritePreamble(uri, opType, dataLength, ContentDistributionNotChunked, defaultContentType)
}

// GetMessage returns the finished message: the preamble, the accumulated
// headers, the end-of-headers token and then data.
// Call it once all headers have been added. The end-of-headers token is
// appended here, so it must not be written anywhere else. Pass "" for data
// when the message has no body. msg itself is not modified.
func (msg *Message) GetMessage(data string) string {
	terminatedHeaders := msg.HeaderData + string(HeaderTokenEndHeaders)

	return msg.PreambleData + terminatedHeaders + data
}

// AddCustomHeader appends a custom name/value header to the message: the
// custom header token followed by headerName and headerValue, each written
// as a UTF-8 counted string.
func (msg *Message) AddCustomHeader(headerName string, headerValue string) {
	msg.HeaderData += string(HeaderTokenCustom)

	// Name first, then value; both use the same counted-string encoding.
	for _, field := range [2]string{headerName, headerValue} {
		addCountedString(&msg.HeaderData, StringEncodingUtf8, field)
	}
}

// AddURIHeader appends a URI-valued header to the message: headerToken, the
// counted-string data format byte, then uri as a UTF-8 counted string.
// uri should probably look something like tcp://1.2.3.4:2814/SomeEndpoint
func (msg *Message) AddURIHeader(uri string, headerToken HeaderToken) {
	msg.HeaderData += string(headerToken) + string(HeaderDataFormatCountedString)
	addCountedString(&msg.HeaderData, StringEncodingUtf8, uri)
}

// AddContentTypeHeader appends a content type header to the message: the
// content type token, the counted-string data format byte, then contentType
// as a UTF-8 counted string.
// contentType will be application/octet-stream almost every time, but it is
// left configurable.
func (msg *Message) AddContentTypeHeader(contentType string) {
	// Assemble the whole entry first, then append it in one step.
	entry := string(HeaderTokenContentType) + string(HeaderDataFormatCountedString)
	addCountedString(&entry, StringEncodingUtf8, contentType)
	msg.HeaderData += entry
}

// AddStatusPhraseHeader appends a status phrase header to the message: the
// status phrase token, the counted-string data format byte, then statusPhrase
// as a UTF-8 counted string.
//
// The data format byte must be present. Every non-custom header carries one,
// and readHeadersFromConn in this package reads it immediately after the
// token; without it the first byte of the counted string would be taken as
// the format and the rest of the header list would be misparsed.
func (msg *Message) AddStatusPhraseHeader(statusPhrase string) { // untested
	msg.HeaderData += string(HeaderTokenStatusPhrase)
	msg.HeaderData += string(HeaderDataFormatCountedString)
	addCountedString(&msg.HeaderData, StringEncodingUtf8, statusPhrase)
}

// AddCloseConnectionHeader appends a close connection header to the message.
// The header has no value, so it consists only of the close connection token
// followed by the void data format byte.
func (msg *Message) AddCloseConnectionHeader() { // untested
	msg.HeaderData += string(HeaderTokenCloseConnection) + string(HeaderDataFormatVoid)
}

// AddStatusCodeHeader appends a status code header to the message: the status
// code token, the uint16 data format byte, then the status value packed as a
// little endian 16-bit integer (1 when isError is true, 0 for success).
//
// The data format byte must be present. Every non-custom header carries one,
// and readHeadersFromConn in this package reads it immediately after the
// token; without it the low byte of the status value would be taken as the
// format and the rest of the header list would be misparsed.
func (msg *Message) AddStatusCodeHeader(isError bool) { // untested
	msg.HeaderData += string(HeaderTokenStatusCode)
	msg.HeaderData += string(HeaderDataFormatUint16)
	if isError {
		msg.HeaderData += transform.PackLittleInt16(1)

		return
	}
	// success
	msg.HeaderData += transform.PackLittleInt16(0)
}

// Dump writes every field of the parsed response to the framework status log.
//
// MajorVersion, MinorVersion, OperationType and ContentDistribution hold raw
// protocol bytes, so they are printed hex encoded instead of as text. Headers
// are printed in map iteration order, which is not stable between calls.
func (msg *MessageResponse) Dump() {
	output.PrintFrameworkStatus("Contents of message:")
	output.PrintfFrameworkStatus("Major Version: %x", msg.MajorVersion)
	output.PrintfFrameworkStatus("Minor Version: %x", msg.MinorVersion)
	output.PrintfFrameworkStatus("Operation Type: %x", msg.OperationType)
	output.PrintfFrameworkStatus("ContentDistribution: %x", msg.ContentDistribution)
	if len(msg.Headers) > 0 {
		// Use the framework logger here as well so the whole dump goes to one place.
		output.PrintFrameworkStatus("Headers:")
		for key, val := range msg.Headers {
			output.PrintfFrameworkStatus("Header Key: %s | Header Value: %s", key, val)
		}
	}
	output.PrintfFrameworkStatus("DataLength: %d", msg.DataLength)
	output.PrintfFrameworkStatus("Data: %s", msg.Data)
}

// Parsing functions.

// ParseResponseFromConn reads a single .NET remoting message from conn and
// parses it into a MessageResponse.
//
// The message is consumed in wire order: the ".NET" magic, one byte each of
// major and minor version, a two byte operation type, a two byte content
// distribution, a four byte little endian data length, the header list and
// finally DataLength bytes of body. The version, operation type and content
// distribution fields are stored as the raw bytes received.
//
// Every fixed-width field is read with io.ReadFull because a single Read on a
// stream connection may return fewer bytes than requested, which would leave
// the field partially filled and desynchronize the rest of the parse.
//
// The body is read as exactly DataLength bytes whatever the content
// distribution value is; chunked bodies are not decoded here.
//
// It returns the parsed message and true on success. On any read failure or
// magic mismatch it logs a framework error and returns an empty
// MessageResponse and false.
func ParseResponseFromConn(conn net.Conn) (MessageResponse, bool) {
	// readField fills buf completely from conn and logs which field failed.
	readField := func(buf []byte, what string) bool {
		if _, err := io.ReadFull(conn, buf); err != nil {
			output.PrintfFrameworkError("Could not parse %s from response, err=%s", what, err)

			return false
		}

		return true
	}

	msg := MessageResponse{}

	// checking magic bytes from message
	magicBuf := make([]byte, 4)
	if !readField(magicBuf, "magic") {
		return MessageResponse{}, false
	}

	if string(magicBuf) != ".NET" {
		output.PrintfFrameworkError("Magic mismatch: received: %s/%x, expected: '.NET'", string(magicBuf), magicBuf)

		return MessageResponse{}, false
	}

	// finish parsing preamble
	majorVerBuf := make([]byte, 1)
	if !readField(majorVerBuf, "major version") {
		return MessageResponse{}, false
	}
	msg.MajorVersion = string(majorVerBuf)

	minorVerBuf := make([]byte, 1)
	if !readField(minorVerBuf, "minor version") {
		return MessageResponse{}, false
	}
	msg.MinorVersion = string(minorVerBuf)

	opTypeBuf := make([]byte, 2)
	if !readField(opTypeBuf, "operation type") {
		return MessageResponse{}, false
	}
	msg.OperationType = string(opTypeBuf)

	contentDistributionBuf := make([]byte, 2)
	if !readField(contentDistributionBuf, "content distribution") {
		return MessageResponse{}, false
	}
	msg.ContentDistribution = string(contentDistributionBuf)

	dataLengthBuf := make([]byte, 4)
	if !readField(dataLengthBuf, "data length") {
		return MessageResponse{}, false
	}
	msg.DataLength = int(binary.LittleEndian.Uint32(dataLengthBuf))

	// take care of the headers
	headers, ok := readHeadersFromConn(conn)
	if !ok {
		output.PrintFrameworkError("Failed parsing headers from response")

		return MessageResponse{}, false
	}

	msg.Headers = headers

	msg.Data, ok = readNBytes(conn, msg.DataLength)
	if !ok {
		output.PrintFrameworkError("Failed reading data from response")

		return MessageResponse{}, false
	}

	return msg, true
}

// readHeadersFromConn reads header entries from conn until the end-of-headers
// token is seen and returns them as a map.
//
// Each entry starts with a two byte token. A custom header is followed by two
// counted strings (name, then value) and is stored under its name. Any other
// header is followed by a one byte data format and a value of that format,
// and is stored under its raw two byte token; fixed-width values are stored
// as the raw bytes received.
//
// All fixed-width reads use io.ReadFull so a short Read cannot leave a buffer
// partially filled. An unrecognized data format is treated as a failure: the
// size of its value is unknown, so the following bytes could not be
// interpreted correctly.
//
// It returns the headers and true on success, or an empty map and false after
// logging a framework error.
//
//nolint:gocognit
func readHeadersFromConn(conn net.Conn) (map[string]string, bool) {
	readHeaders := make(map[string]string)
	tokenBuf := make([]byte, 2)
	dataTypeBuf := make([]byte, 1)

	// read initial token
	if _, err := io.ReadFull(conn, tokenBuf); err != nil {
		output.PrintfFrameworkError("Failed reading initial token value from response, err=%s", err)

		return map[string]string{}, false
	}

	// while we have not read the End of Headers 'token'
	for string(tokenBuf) != string(HeaderTokenEndHeaders) {
		name := string(tokenBuf)
		value := ""

		switch string(tokenBuf) {
		case string(HeaderTokenCustom): // HeaderTokenCustom
			// untested
			str, ok := readHeaderStringFromConn(conn)
			if !ok {
				output.PrintFrameworkError("Failed reading custom header name from response")

				return map[string]string{}, false
			}
			name = str

			str, ok = readHeaderStringFromConn(conn)
			if !ok {
				output.PrintFrameworkError("Failed reading custom header value from response")

				return map[string]string{}, false
			}
			value = str

		default:
			if _, err := io.ReadFull(conn, dataTypeBuf); err != nil {
				output.PrintfFrameworkError("Failed reading data type, err=%s", err)

				return map[string]string{}, false
			}

			switch string(dataTypeBuf) {
			case string(HeaderDataFormatVoid):
				// no value follows a void header
			case string(HeaderDataFormatCountedString):
				data, ok := readHeaderStringFromConn(conn)
				if !ok {
					output.PrintFrameworkError("Failed reading counted header string")

					return map[string]string{}, false
				}
				value = data
			case string(HeaderDataFormatByte):
				dataBuf := make([]byte, 1)
				if _, err := io.ReadFull(conn, dataBuf); err != nil {
					output.PrintfFrameworkError("Failed reading format byte, err=%s", err)

					return map[string]string{}, false
				}
				value = string(dataBuf)
			case string(HeaderDataFormatUint16):
				dataBuf := make([]byte, 2)
				if _, err := io.ReadFull(conn, dataBuf); err != nil {
					output.PrintfFrameworkError("Failed reading uint16, err=%s", err)

					return map[string]string{}, false
				}
				value = string(dataBuf)
			case string(HeaderDataFormatInt32):
				dataBuf := make([]byte, 4)
				if _, err := io.ReadFull(conn, dataBuf); err != nil {
					output.PrintfFrameworkError("Failed reading int32, err=%s", err)

					return map[string]string{}, false
				}
				value = string(dataBuf)
			default:
				// Unknown format: the value length is unknown, so stop instead of misparsing.
				output.PrintfFrameworkError("Unknown header data format %x for header token %x", dataTypeBuf, tokenBuf)

				return map[string]string{}, false
			}
		}

		output.PrintfFrameworkTrace("Parsed header: %s = %s", name, value)
		readHeaders[name] = value

		if _, err := io.ReadFull(conn, tokenBuf); err != nil {
			output.PrintfFrameworkError("Failed reading token value, err=%s", err)

			return map[string]string{}, false
		}
		output.PrintfFrameworkTrace("token value %x", tokenBuf)
	}
	output.PrintFrameworkTrace("done parsing headers")

	return readHeaders, true
}

// readHeaderStringFromConn reads one counted string from conn: a one byte
// string encoding tag, a four byte little endian length, then that many bytes
// of string data.
//
// The encoding tag is read but not interpreted, so the string bytes are
// returned exactly as received whatever encoding the peer declared.
//
// The tag and length are read with io.ReadFull so a short Read cannot yield a
// partially filled length field.
//
// It returns the string and true on success, or "" and false after logging a
// framework error.
func readHeaderStringFromConn(conn net.Conn) (string, bool) {
	encodingTypeBuf := make([]byte, 1)
	if _, err := io.ReadFull(conn, encodingTypeBuf); err != nil {
		output.PrintfFrameworkError("Failed reading encoding type from header string, err=%s", err)

		return "", false
	}

	stringLengthBuf := make([]byte, 4)
	if _, err := io.ReadFull(conn, stringLengthBuf); err != nil {
		output.PrintfFrameworkError("Failed reading string length from header string, err=%s", err)

		return "", false
	}

	// encodingType := string(encodingTypeBuf) // sorry, just going to ignore this for now.
	stringLength := int(binary.LittleEndian.Uint32(stringLengthBuf))

	stringData, ok := readNBytes(conn, stringLength)
	if !ok {
		return "", false
	}

	return stringData, true
}

// readNBytes reads exactly n bytes from conn and returns them as a string.
//
// A non-positive n reads nothing and succeeds with "". The read goes through
// io.LimitReader and io.ReadAll, so the buffer grows as data arrives instead
// of being allocated up front from a length the peer supplied, and a Read
// that returns zero bytes without an error adds nothing to the result.
//
// It returns the data and true on success. If the connection fails or ends
// before n bytes arrive it logs a framework error and returns "" and false.
func readNBytes(conn net.Conn, n int) (string, bool) {
	if n <= 0 {
		return "", true
	}

	data, err := io.ReadAll(io.LimitReader(conn, int64(n)))
	// io.ReadAll reports a clean EOF as success, so the length must be checked too.
	if err != nil || len(data) != n {
		output.PrintfFrameworkError("Failed reading N bytes from connection")

		return "", false
	}

	return string(data), true
}

// Helper functions.

// addCountedString appends stringValue to *msg in counted-string form: the
// encoding tag, the byte length of stringValue packed as a little endian
// 32-bit value, then the string bytes themselves.
// It is private on purpose, to promote the Add*Header helpers and keep things
// easier to use.
//
//nolint:unparam
func addCountedString(msg *string, encodingType StringEncoding, stringValue string) {
	lengthPrefix := transform.PackLittleInt32(len(stringValue))
	*msg += string(encodingType) + lengthPrefix + stringValue
}
