#!/bin/bash

me="${0##*/}"

# Convert an IPv4 prefix length to a dotted-quad netmask.
# Arguments: $1 - number of mask bits (e.g. 24)
# Outputs:   the netmask in x.x.x.x form on stdout (e.g. 255.255.255.0)
cidr2mask() {
	local whole=$(($1 / 8))
	local rest=$(($1 % 8))
	local idx octet
	local -a octets=()

	for idx in 0 1 2 3; do
		if ((idx < whole)); then
			# octet entirely covered by the prefix
			octet=255
		elif ((idx == whole)); then
			# octet that the prefix ends in: keep the top $rest bits
			octet=$((256 - 2 ** (8 - rest)))
		else
			# octet entirely in the host part
			octet=0
		fi
		octets+=("$octet")
	done

	# join the four octets with dots
	local IFS=.
	echo "${octets[*]}"
}

# Convert an IPv6 prefix length to a colon-separated mask made of eight
# 16-bit hexadecimal groups (e.g. 60 -> ffff:ffff:ffff:fff0:0:0:0:0).
# Arguments: $1 - prefix length in bits, an integer in 0..128
# Outputs:   the mask on stdout; an error message on stderr for bad input
# Returns:   0 on success, 1 if the prefix length is invalid
cidr2maskipv6() {
	local mask=""
	local num_bits=$1
	local i group

	# Reject anything that is not a plain decimal number in range. The
	# message goes to stderr so a caller capturing stdout never mistakes
	# it for a mask.
	if ! [[ $num_bits =~ ^[0-9]+$ ]] || ((10#$num_bits > 128)); then
		echo "Invalid prefix length for IPv6" >&2
		return 1
	fi
	# force base 10 so a leading zero is not read as octal
	num_bits=$((10#$num_bits))

	local full_blocks=$((num_bits / 16))
	local host_bits=$((16 - num_bits % 16))

	for ((i = 0; i < 8; i++)); do
		if ((i < full_blocks)); then
			mask+="ffff"
		elif ((i == full_blocks)); then
			# The network bits are the HIGH-order bits of the group:
			# clear the low $host_bits bits and keep the result to 16 bits.
			printf -v group '%x' $(((0xFFFF << host_bits) & 0xFFFF))
			mask+=$group
		else
			mask+="0"
		fi
		if ((i < 7)); then
			mask+=":"
		fi
	done
	echo "$mask"
}

# Compute the IPv4 network address of an address/netmask pair.
# Arguments: $1 - IPv4 address in dotted-quad form
#            $2 - netmask in dotted-quad form
# Outputs:   the bitwise AND of the two, in dotted-quad form, on stdout
netcalc() {
	local -a addr netmask
	local idx network=""

	# split both arguments into their four octets
	IFS=. read -r -a addr <<< "$1"
	IFS=. read -r -a netmask <<< "$2"

	for idx in 0 1 2 3; do
		network+=$(( ${addr[idx]} & ${netmask[idx]} ))
		if ((idx < 3)); then
			network+="."
		fi
	done
	echo "$network"
}

# Expand an IPv6 address to its full textual form: every group padded to
# four hex digits and a "::" abbreviation replaced by the zero groups it
# stands for.
# Arguments: $1 - IPv6 address (e.g. fe80::1, ::1, 2001:db8::)
# Outputs:   the expanded address on stdout
#            (e.g. fe80:0000:0000:0000:0000:0000:0000:0001)
# An address without "::" is only padded group by group; the number of
# groups is not validated.
expand_ipv6() {
	local ip=$1
	local expanded_ip="" group padded
	local -a groups=() head=() tail=()
	local missing n

	if [[ $ip == *::* ]]; then
		# Split around the "::" rather than on every ":". A leading "::"
		# would otherwise produce two empty fields and the zero run would
		# be inserted twice.
		local head_str=${ip%%::*}
		local tail_str=${ip#*::}
		if [[ -n $head_str ]]; then
			IFS=':' read -r -a head <<< "$head_str"
		fi
		if [[ -n $tail_str ]]; then
			IFS=':' read -r -a tail <<< "$tail_str"
		fi

		# "::" stands for however many groups are needed to reach eight
		missing=$((8 - ${#head[@]} - ${#tail[@]}))
		groups=("${head[@]}")
		for ((n = 0; n < missing; n++)); do
			groups+=("0")
		done
		groups+=("${tail[@]}")
	else
		IFS=':' read -r -a groups <<< "$ip"
	fi

	for group in "${groups[@]}"; do
		# right-align in four columns, then turn the padding into zeros
		printf -v padded '%4s' "$group"
		expanded_ip+="${padded// /0}:"
	done

	# drop the trailing colon
	echo "${expanded_ip%:}"
}

# Compute the IPv6 network address of an address/mask pair.
# Arguments: $1 - IPv6 address (expanded via expand_ipv6 before use)
#            $2 - mask as eight colon-separated hex groups
# Outputs:   the group-wise AND of the two as eight 4-digit hex groups,
#            colon-separated, on stdout
netcalcipv6() {
	local mask=$2
	local ip
	local -a ip_blocks mask_blocks net_blocks=()
	local anded hex

	# bring the address to its full eight-group form
	ip=$(expand_ipv6 $1)

	# turn the colons into spaces and let word splitting build the arrays
	ip_blocks=(${ip//:/ })
	mask_blocks=(${mask//:/ })

	for ((i = 0; i < 8; i++)); do
		# groups are hexadecimal, hence the 16# base prefix
		anded=$(( 16#${ip_blocks[i]} & 16#${mask_blocks[i]} ))
		printf -v hex '%04x' "$anded"
		net_blocks+=("$hex")
	done

	# join the eight groups with colons
	local IFS=:
	echo "${net_blocks[*]}"
}

# Honour the ksocklnd skip_mr_route_setup module parameter: when it reads
# "1" the script exits successfully without touching any route.
checkskipcmd=$(cat /sys/module/ksocklnd/parameters/skip_mr_route_setup 2>&-)
case "$checkskipcmd" in
	1) exit 0 ;;
esac

# Build the list of interfaces to configure.
# $1 is a comma-separated list of interface names. An interface is kept in
# the "interfaces" array only if it has at least one IPv4 or IPv6 address
# and the routing table named after it is still empty for every address
# family it has an address in.
j=0
declare -a interfaces
for i in $(echo $1 | sed "s/,/ /g")
do
	# Collect the addresses of the interface: field 4 of the one-line
	# "ip addr" output with everything from the "/" on removed. An unknown
	# interface simply yields two empty strings.
	ipv4_addr=$(/sbin/ip -o -4 addr list $i 2>&- | awk '{print $4}' | cut -d/ -f1)
	ipv6_addr=$(/sbin/ip -o -6 addr list $i 2>&- | awk '{print $4}' | cut -d/ -f1)

	if [ -z "$ipv4_addr" ] && [ -z "$ipv6_addr" ]; then
		# No IPv4 or IPv6 address configured on this interface, skip it
		logcmd=(logger "${me}: skip setting up route for ${i}: IP address not found")
		eval "${logcmd[@]}"
		continue
	fi

	# Check if route is already set up for this interface (IPv4 or IPv6).
	# The table looked up is the one named after the interface.
	if [ ! -z "$ipv4_addr" ]; then
		intfroute_ipv4=$(/sbin/ip -o -4 route show table $i 2>&-)
		if [ ! -z "$intfroute_ipv4" ]; then
			# NOTE(review): this prints the existing route to stdout; the
			# IPv6 branch below has no equivalent - confirm it is intended
			# and not leftover debugging output.
			echo $intfroute_ipv4
			# IPv4 route exists, skip this interface
			logcmd=(logger "${me}: skip setting up route for ${i}: IPv4 route exists")
			eval "${logcmd[@]}"
			continue
		fi
	fi

	if [ ! -z "$ipv6_addr" ]; then
		intfroute_ipv6=$(/sbin/ip -o -6 route show table $i 2>&-)
		if [ ! -z "$intfroute_ipv6" ]; then
			# IPv6 route exists, skip this interface
			logcmd=(logger "${me}: skip setting up route for ${i}: IPv6 route exists")
			eval "${logcmd[@]}"
			continue
		fi
	fi

	# interface passed every check: append it to the work list
	interfaces[$j]=$i
	j=$((j+1))
done

# Names of the selected interfaces that are already listed in rt_tables
interfaces_listed=()

# Flush the routing table named after each selected interface and log the
# command that was run
for i in "${interfaces[@]}"; do
	# The redirection is stored as text, so the command has to go through
	# eval for it to take effect.
	redirect="2>&-"
	flushcmd=(/sbin/ip route flush table ${i} ${redirect})
	eval "${flushcmd[*]}"

	logcmd=(logger "${me}: ${flushcmd[*]}")
	eval "${logcmd[*]}"
done

# Make sure every selected interface has a routing table entry in
# rt_tables. The file is scanned once to find the highest table number in
# use and the interfaces already listed; a "<number> <interface>" line is
# then appended for each interface that is missing.
filename='/etc/iproute2/rt_tables'
# n is incremented for every line that is not skipped; it is not read
# anywhere in this part of the script
n=1
max_table_num=0
while read line; do
	# reading each line
	# strip leading whitespace (the unquoted echo also collapses runs of
	# whitespace inside the line)
	line=`echo $line | sed -e 's/^[[:space:]]*//'`
	linelen=$(echo -n $line | wc -m)
	# don't check empty lines
	if [ $linelen -lt 1 ]; then
		continue
	fi
	# don't check comments
	if [[ ${line:0:1} == "#" ]]; then
		continue
	fi
	# split on whitespace; element 0 is taken to be the table number
	splitline=( $line )
	# check the table number and update the max
	# NOTE(review): assumes the first field is a decimal integer - confirm
	# no entry in the file uses another notation
	if [ $max_table_num -lt ${splitline[0]} ]; then
		max_table_num=${splitline[0]}
	fi
	# check if any of the interfaces are listed
	for i in "${interfaces[@]}"
	do
		# the quoted right-hand side makes =~ a literal substring match of
		# the space-delimited name
		if [[ " ${splitline[@]} " =~ " ${i} " ]]; then
			# always true here, since i is itself taken from interfaces
			if [[ " ${interfaces[@]} " =~ " ${i} " ]]; then
				interfaces_listed+=($i)
			fi
		fi
	done
	n=$((n+1))
done < $filename

# add entries for unlisted interfaces, numbering them upwards from the
# highest table number found above
for i in "${interfaces[@]}"
do
	if [[ ! " ${interfaces_listed[@]} " =~ " ${i} " ]]; then
		max_table_num=$((max_table_num+1))
		echo "$max_table_num $i" >> $filename
	fi
done

# Collect the default gateways that are not bound to a device: field 3 of
# every "default" route line that does not mention "dev". One list is kept
# for IPv4 and one for IPv6.
gwsline=$(/sbin/ip route | awk '/default/ && !/dev/ { print $3 }')
comm_gateways=($gwsline)

gwsline_ipv6=$(/sbin/ip -6 route | awk '/default/ && !/dev/ { print $3 }')
comm_gateways_ipv6=($gwsline_ipv6)

# Probe a gateway with a single ping.
# Arguments: $1 - gateway address; one containing ":" is treated as IPv6
# Returns:   0 if the gateway answered, 1 otherwise (a message is logged)
pinggw() {
	local gw=$1
	local timeout=1  # seconds to wait for the reply
	local prober=ping

	# IPv6 addresses are probed with ping6, everything else with ping
	if [[ $gw == *:* ]]; then
		prober=ping6
	fi

	if "$prober" -c 1 -W $timeout "$gw" &> /dev/null; then
		return 0  # reachable
	fi

	logcmd=(logger "${me}: unreachable default gateway ${gw}: skipping ")
	eval "${logcmd[@]}"
	return 1  # not reachable
}

# Select a reachable default gateway on the same subnet as an address.
# Arguments: $1 - IPv4 or IPv6 address (one containing ":" is IPv6)
#            $2 - netmask in the notation matching $1 (dotted quad, or
#                 eight colon-separated hex groups)
#            $3 - interface name
# Globals:   comm_gateways, comm_gateways_ipv6 (read)
# Outputs:   the selected gateway on stdout, or "0.0.0.0" when none
#            qualifies (for both address families)
# Default gateways of the interface itself are tried first, then the
# gateways that are not bound to any device. A candidate qualifies when it
# falls in the same network as $1 and answers a ping.
selectgw() {
	local ip=$1
	local mask=$2
	local interface=$3
	local calc network gw gw_network
	local -a candidates=()

	# Pick the network calculator and the candidate list per family; the
	# selection logic below is then identical for IPv4 and IPv6.
	if [[ $ip == *:* ]]; then
		# IPv6
		mask=$(expand_ipv6 "$mask")
		calc=netcalcipv6
		candidates=($(/sbin/ip -6 route show dev "$interface" | awk '/default/ { print $3 }'))
		candidates+=("${comm_gateways_ipv6[@]}")
	else
		# IPv4
		calc=netcalc
		candidates=($(/sbin/ip route show dev "$interface" | awk '/default/ { print $3 }'))
		candidates+=("${comm_gateways[@]}")
	fi

	# Use the same helper for the address and for every gateway so both
	# sides of the comparison are always formatted alike.
	network=$("$calc" "$ip" "$mask")

	for gw in "${candidates[@]}"; do
		gw_network=$("$calc" "$gw" "$mask")
		if [[ "$network" == "$gw_network" ]] && pinggw "$gw"; then
			echo "$gw"
			return
		fi
	done

	echo "0.0.0.0"
}

# Add the routing entries and rules for IPv4 and/or IPv6.
# For every selected interface and address family: add one route to the
# table named after the interface (a default route via the selected
# gateway, or a link-scope route when no gateway qualified), then replace
# the "from <address>" policy rule pointing at that table. Each command is
# logged together with whatever it printed on stderr.
for i in "${interfaces[@]}"
do
	# Extract IPv4 and IPv6 addresses and netmasks in CIDR format.
	# Each array holds one entry per address; only element 0 is used below.
	addr_ipv4=($(/sbin/ip -o -4 addr list $i 2>&- | awk '{print $4}' | cut -d/ -f1))
	cidrmask_ipv4=($(/sbin/ip -o -4 addr list $i 2>&- | awk '{print $4}' | cut -d/ -f2))
	addr_ipv6=($(/sbin/ip -o -6 addr list $i 2>&- | awk '{print $4}' | cut -d/ -f1))
	cidrmask_ipv6=($(/sbin/ip -o -6 addr list $i 2>&- | awk '{print $4}' | cut -d/ -f2))
	# Configure routing and rules for IPv4 (if IPv4 address is configured);
	# "${addr_ipv4}" without a subscript expands to element 0
	if [ ! -z "${addr_ipv4}" ]; then
		# Convert CIDR mask to mask in dot format for IPv4
		dotmask_ipv4=$(cidr2mask ${cidrmask_ipv4[0]})
		# Find a gateway on the same subnet for IPv4
		gw_ipv4=$(selectgw "${addr_ipv4[0]}" "$dotmask_ipv4" "$i")
		# Build and execute route commands for IPv4
		if [[ $gw_ipv4 == "0.0.0.0" ]]; then
			# Gateway not found, assume local destinations for IPv4
			net_ipv4=$(netcalc "${addr_ipv4[0]}" "$dotmask_ipv4")
			routecmd_ipv4=(/sbin/ip route add ${net_ipv4}/${cidrmask_ipv4[0]} dev ${i} proto kernel scope link src ${addr_ipv4[0]} table ${i})
		else
			routecmd_ipv4=(/sbin/ip route add default via ${gw_ipv4} dev ${i} table ${i})
		fi
		# the quoted '&>/dev/null' element only becomes a redirection
		# because the command is run through eval
		ruledelcmd_ipv4=(/sbin/ip rule del from ${addr_ipv4[0]} table ${i} '&>/dev/null')
		ruleaddcmd_ipv4=(/sbin/ip rule add from ${addr_ipv4[0]} table ${i})

		# "2>&1 >/dev/null" captures stderr only: stderr is pointed at the
		# command substitution before stdout is discarded
		routeerr_ipv4=$(eval "${routecmd_ipv4[@]}" 2>&1 >/dev/null)
		ruledelerr_ipv4=$(eval "${ruledelcmd_ipv4[@]}" 2>&1 >/dev/null)
		ruleadderr_ipv4=$(eval "${ruleaddcmd_ipv4[@]}" 2>&1 >/dev/null)

		# log each command followed by its captured error output; eval
		# re-parses the message text, so it reaches logger as separate words
		logcmd1_ipv4=(logger "${me}: ${routecmd_ipv4[@]} ${routeerr_ipv4}")
		logcmd2_ipv4=(logger "${me}: ${ruledelcmd_ipv4[@]} ${ruledelerr_ipv4}")
		logcmd3_ipv4=(logger "${me}: ${ruleaddcmd_ipv4[@]} ${ruleadderr_ipv4}")

		eval "${logcmd1_ipv4[@]}"
		eval "${logcmd2_ipv4[@]}"
		eval "${logcmd3_ipv4[@]}"
	fi

	# Configure routing and rules for IPv6 (if IPv6 address is configured)
	if [ ! -z "${addr_ipv6}" ]; then
		# Convert CIDR mask to a colon-separated hex mask for IPv6
		dotmask_ipv6=$(cidr2maskipv6 ${cidrmask_ipv6[0]})
		# Find a gateway on the same subnet for IPv6
		gw_ipv6=$(selectgw "${addr_ipv6[0]}" "$dotmask_ipv6" "$i")
		# Build and execute route commands for IPv6
		# (selectgw reports "no gateway" as 0.0.0.0 for IPv6 as well)
		if [[ $gw_ipv6 == "0.0.0.0" ]]; then
			# Gateway not found, assume local destinations for IPv6
			# NOTE(review): net_ipv6 is computed but the route below uses the
			# interface address, not net_ipv6, as the prefix - the IPv4 branch
			# uses the network address. Confirm which one is intended.
			net_ipv6=$(netcalcipv6 "${addr_ipv6[0]}" "$dotmask_ipv6")
			routecmd_ipv6=(/sbin/ip -6 route add ${addr_ipv6[0]}/${cidrmask_ipv6[0]} dev ${i} proto kernel scope link src ${addr_ipv6[0]} table ${i})
		else
			routecmd_ipv6=(/sbin/ip -6 route add default via ${gw_ipv6} dev ${i} table ${i})
		fi
		# as above, '&>/dev/null' takes effect only through eval
		ruledelcmd_ipv6=(/sbin/ip -6 rule del from ${addr_ipv6[0]} table ${i} '&>/dev/null')
		ruleaddcmd_ipv6=(/sbin/ip -6 rule add from ${addr_ipv6[0]} table ${i})

		# capture stderr only, as in the IPv4 branch
		routeerr_ipv6=$(eval "${routecmd_ipv6[@]}" 2>&1 >/dev/null)
		ruledelerr_ipv6=$(eval "${ruledelcmd_ipv6[@]}" 2>&1 >/dev/null)
		ruleadderr_ipv6=$(eval "${ruleaddcmd_ipv6[@]}" 2>&1 >/dev/null)

		# "--" ends logger's option parsing - presumably so that the "-6" in
		# the logged command text is not taken for a logger option once eval
		# has split the message into words
		logcmd1_ipv6=(logger -- "${me}: ${routecmd_ipv6[@]} ${routeerr_ipv6}")
		logcmd2_ipv6=(logger -- "${me}: ${ruledelcmd_ipv6[@]} ${ruledelerr_ipv6}")
		logcmd3_ipv6=(logger -- "${me}: ${ruleaddcmd_ipv6[@]} ${ruleadderr_ipv6}")

		eval "${logcmd1_ipv6[@]}"
		eval "${logcmd2_ipv6[@]}"
		eval "${logcmd3_ipv6[@]}"
	fi
done

# Flush the neighbour table entries of every selected interface
for i in "${interfaces[@]}"; do
	flushcmd=(/sbin/ip neigh flush dev ${i})
	eval "${flushcmd[*]}"
done

