#!/bin/bash
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# For a single-queue / no MSI-X virtionet device, sets the IRQ affinities to
# processor 0. For this virtionet configuration, distributing IRQs to all
# processors results in comparatively high cpu utilization and comparatively
# low network bandwidth.
#
# For a multi-queue / MSI-X virtionet device, sets the IRQ affinities to the
# per-IRQ affinity hint. The virtionet driver maps each virtionet TX (RX) queue
# MSI-X interrupt to a unique single CPU if the number of TX (RX) queues equals
# the number of online CPUs. The mapping of network MSI-X interrupt vector to
# CPUs is stored in the virtionet MSI-X interrupt vector affinity hint. This
# configuration allows network traffic to be spread across the CPUs, giving
# each CPU a dedicated TX and RX network queue, while ensuring that all packets
# from a single flow are delivered to the same CPU.
#
# For a gvnic device, set the IRQ affinities to the per-IRQ affinity hint.
# The google virtual ethernet driver maps each queue MSI-X interrupt to a
# unique single CPU, which is stored in the affinity_hint for each MSI-X
# vector. In older versions of the kernel, irqbalance is expected to copy the
# affinity_hint to smp_affinity; however, GCE instances disable irqbalance by
# default. This script copies over the affinity_hint to smp_affinity on boot to
# replicate the behavior of irqbalance.

# Allow changing the metadata server and "root" for testing.
METADATA_SERVER="169.254.169.254"
ROOT_DIR="/"
# The '+x' expansion distinguishes "unset" from "set but empty", so even an
# explicitly empty override is honored.
if [[ -n "${GOOGLE_SET_MULTIQUEUE_METADATASERVER+x}" ]]; then
  METADATA_SERVER="$GOOGLE_SET_MULTIQUEUE_METADATASERVER"
fi
if [[ -n "${GOOGLE_SET_MULTIQUEUE_FAKEROOT+x}" ]]; then
  ROOT_DIR="$GOOGLE_SET_MULTIQUEUE_FAKEROOT"
fi

# Fetches a value from the GCE metadata server.
# Arguments: $1 - metadata path below computeMetadata/v1/
# Outputs: the raw metadata value on stdout; -m 1 caps the request at one
# second so boot does not stall when the server is unreachable.
function get_metadata() {
  curl -s -m 1 -H "Metadata-Flavor: Google" "http://${METADATA_SERVER}/computeMetadata/v1/${1}"
}

# Hex CPU bitmap (sysfs xps_cpus format) covering all vCPUs on A4X machines.
A4X_ALL_CPUS_MASK="0000ffff,ffffffff,ffffffff,ffffffff,ffffffff"
# Ring buffer lengths applied to idpf NICs on a4x-maxgpu machines via
# 'ethtool -G'.
A4X_RX_RING_LENGTH="2048"
A4X_TX_RING_LENGTH="1024"

# Returns 0 if the argument parses as a decimal integer.
# 'test -eq' only succeeds when both operands are valid integers, so
# comparing the value against itself validates the format.
function is_decimal_int() {
  test "${1}" -eq "${1}" 2> /dev/null
}

# Sets the number of combined channels on a NIC, silently.
# Arguments: $1 - interface name, $2 - channel count
# Returns: ethtool's exit status.
function set_channels() {
  local -r dev="$1"
  local -r channels="$2"
  ethtool -L "$dev" combined "$channels" > /dev/null 2>&1
}

# Binds each idpf TxRx IRQ of a NIC to one core drawn from the supplied core
# list, starting at the supplied index.
# Arguments:
#   $1   - network interface name
#   $2   - starting index into the core list
#   $3.. - candidate core ids
# Outputs: the next unused core index on stdout (the caller captures it via
#   command substitution); progress logging goes to stderr.
function set_irq_range_idpf() {
  local -r nic="$1"
  local bind_cores_index="$2"
  local irq_ranges=("${@:3}")

  # The user may not have this $nic configured on their VM, if not, just skip
  # it, no need to error out. Still emit the unchanged index so the caller's
  # command-substitution capture does not come back empty.
  if [ ! -d "${ROOT_DIR}sys/class/net/"$nic"/device" ]; then
    echo "$bind_cores_index"
    return
  fi

  # More IRQs are in msi_irqs than actually exist for idpf.
  local alleged_irqs=($(ls ${ROOT_DIR}sys/class/net/"$nic"/device/msi_irqs))
  local -a actual_irqs=()
  for irq in "${alleged_irqs[@]}"; do
    if [[ -d "${ROOT_DIR}proc/irq/${irq}" ]]; then
      # Only reaffinitize TxRx irqs, not the Mailbox.
      if ls ${ROOT_DIR}proc/irq/${irq}/*TxRx* 1> /dev/null 2>&1; then
        actual_irqs+=("${irq}")
      fi
    fi
  done
  local num_irqs=${#actual_irqs[@]}
  for ((i=0; i<num_irqs; i+=1)); do
    core="${irq_ranges[${bind_cores_index}]}"
    echo "Setting irq binding for $nic IRQ ${actual_irqs[$i]} to core ${core}" >&2
    echo "${core}" > "${ROOT_DIR}proc/irq/${actual_irqs[$i]}/smp_affinity_list"
    bind_cores_index=$((bind_cores_index + 1))
  done

  # Report the next unused core index; without this echo the caller's
  # bind_cores_index capture would be emptied (set_irq_range_gve already
  # reports it the same way).
  echo "$bind_cores_index"
}

# Binds each gve queue's TX/RX MSI-X IRQ pair to a single core taken from
# the supplied core list, starting at the supplied index.
# Arguments:
#   $1   - network interface name
#   $2   - starting index into the core list
#   $3.. - candidate core ids
# Outputs: the next unused core index on stdout (the caller captures it via
#   command substitution); all progress logging goes to stderr so it does
#   not pollute that capture.
# NOTE(review): num_irqs, num_q and irqs are intentionally not 'local' here
# and leak into the global scope — confirm nothing relies on that.
function set_irq_range_gve() {
  local -r nic="$1"
  local bind_cores_index="$2"
  local irq_ranges=("${@:3}")

  # The user may not have this $nic configured on their VM, if not, just skip
  # it, no need to error out. Emit the unchanged index so the caller's
  # command-substitution capture does not come back empty (previously this
  # path returned nothing, corrupting the caller's bind_cores_index).
  if [ ! -d "${ROOT_DIR}sys/class/net/"$nic"/device" ]; then
    echo "$bind_cores_index"
    return
  fi

  # We count the number of rx queues and assume number of rx queues == tx
  # queues. The number of queues shown in the sysfs stands for the initial
  # queues while the number of IRQs stands for the max queues. The number of
  # initial queues should be always less than or equal to that of the max
  # queues.
  num_irqs=$(( $(ls ${ROOT_DIR}sys/class/net/"$nic"/device/msi_irqs | wc -l) / 2 ))
  num_q=$(ls -1 ${ROOT_DIR}sys/class/net/"$nic"/queues/ | grep rx | wc -l)
  echo "Setting irq binding for "$nic" to core ["${irq_ranges[${bind_cores_index}]}" - "${irq_ranges[$((bind_cores_index + num_q - 1))]}] ... >&2

  # The sorted IRQ list is indexed so that entry i and entry i+num_irqs form
  # queue i's TX/RX vector pair (see the gve_*_idx_to_ntfy comments below).
  irqs=($(ls ${ROOT_DIR}sys/class/net/"$nic"/device/msi_irqs | sort -g))
  for ((irq = 0; irq < "$num_irqs"; irq++)); do
    tx_irq=${irqs[$irq]}
    rx_irq=${irqs[$((irq + num_irqs))]}

    # Only allocate $num_q cores to the IRQs and queues. If the number of IRQs
    # is more than that of queues, the CPUs will be wrapped around.
    # NOTE(review): no modulo is applied here, so an index past the end of
    # irq_ranges yields an empty core — confirm callers always pass enough
    # cores.
    core="${irq_ranges[${bind_cores_index}]}"
    ((bind_cores_index++))

    # this is GVE's TX irq. See gve_tx_idx_to_ntfy().
    echo "$core" > "${ROOT_DIR}proc/irq/${tx_irq}/smp_affinity_list"
    echo "tx_irq: ${tx_irq}, assigned irq core: $(cat "${ROOT_DIR}proc/irq/${tx_irq}/smp_affinity_list")" >&2

    # this is GVE's RX irq. See gve_rx_idx_to_ntfy().
    echo "$core" > "${ROOT_DIR}proc/irq/${rx_irq}/smp_affinity_list"
    echo "rx_irq: ${rx_irq}, assigned irq core: $(cat "${ROOT_DIR}proc/irq/${rx_irq}/smp_affinity_list")" >&2

    # Check if the queue exists at present because the number of IRQs equals
    # the max number of queues allocated and could be greater than the current
    # number of queues.
    tx_queue=${ROOT_DIR}sys/class/net/"$nic"/queues/tx-"$irq"
    if ls $tx_queue 1> /dev/null 2>&1; then
      echo -en "$nic:q-$irq: \ttx: irq $tx_irq bind to $core \trx: irq $rx_irq bind to $core" >&2
      echo -e " \txps_cpus bind to $(cat $tx_queue/xps_cpus)" >&2
    else
      echo -e "$nic:q-$irq: \ttx: irq $tx_irq bind to $core \trx: irq $rx_irq bind to $core" >&2
    fi

  done

  echo "$bind_cores_index"
}

# returns 0 (success) if all the interfaces contain a physical-nic-id on the
# metadata server.
function contains_pnic_ids() {
  # Walks every network interface listed on the metadata server and checks
  # that each one exposes a physical-nic-id attribute.
  local interfaces
  interfaces=$(get_metadata "instance/network-interfaces/")

  local interface attributes
  for interface in $interfaces; do
    echo "Interface: $interface"

    attributes=$(get_metadata "instance/network-interfaces/$interface/")

    # A single interface without the attribute is enough to fail.
    if ! echo "$attributes" | grep -q "physical-nic-id"; then
      echo "physical-nic-id NOT found in interface $interface"
      return 1
    fi
  done

  return 0
}

# returns 0 (success) if the platform is a multinic accelerator platform.
# Decides via the metadata server (physical-nic-id present on every
# interface) or by matching the global $machine_type against known
# accelerator machine families.
function is_multinic_accelerator_platform() {
  contains_pnic_ids
  CONTAINS_PNIC_IDS=$?

  if (( CONTAINS_PNIC_IDS == 0 )); then
    return 0
  fi

  case "$machine_type" in
    *"a3-highgpu-8g"* | *"a3-ultragpu-8g"* | *"a3-megagpu-8g"* | \
    *"a3-edgegpu-8g"* | *"a3-ultragpu-"* | *"a4-highgpu-"* | \
    *"a4x-highgpu-"* | *"a4x-maxgpu-"* | *"c4x-"*)
      return 0
      ;;
  esac
  return 1
}

# returns 0 (success) if the supplied nic is a Gvnic device.
function is_gvnic() {
  local -r nic_name="$1"
  local -r driver_type=$(ethtool -i $nic_name | grep driver)

  # Match either driver name the gve stack may report.
  case "$driver_type" in
    *gve* | *gvnic*) return 0 ;;
    *) return 1 ;;
  esac
}

# returns 0 (success) if the supplied nic is an IDPF device.
function is_idpf() {
  local -r nic_name="$1"
  local -r driver_type=$(ethtool -i $nic_name | grep driver)

  case "$driver_type" in
    *idpf*) return 0 ;;
    *) return 1 ;;
  esac
}

# Returns the CPU ranges for a given NUMA node.
# The CPU ranges will be returned as a space-separated list of start/end integers.
# e.g. cpulist "0-3,8" -> "0 3 8 8 " (a single CPU becomes start==end).
function get_vcpu_ranges {
  local node_id="$1"
  local list_path="${ROOT_DIR}sys/devices/system/node/node${node_id}/cpulist"
  # Nothing to report when the node has no cpulist file.
  [ -f "$list_path" ] || return 0

  local entry lo hi out=""
  for entry in $(tr ',' ' ' < "$list_path"); do
    lo=$(echo "$entry" | cut -d '-' -f 1)
    hi=$(echo "$entry" | cut -d '-' -f 2)
    # Entries without a dash are single CPUs: start == end.
    [[ -z "$hi" ]] && hi=$lo
    out+="$lo $hi "
  done
  echo "$out"
}

# Expands "start end start end ..." pairs into individual CPU ids and
# appends each id to the caller-supplied array (passed by name, $2).
function unpack_cpu_ranges() {
    local -a bounds=($1)
    local -n out_ref="$2"
    local idx lo hi cpu
    for ((idx = 0; idx < ${#bounds[@]}; idx += 2)); do
        lo="${bounds[$idx]}"
        hi="${bounds[$((idx + 1))]}"
        for ((cpu = lo; cpu <= hi; cpu++)); do
            out_ref+=("$cpu")
        done
    done
}

# Converts a hexadecimal bitmap to rangelist
# ex. bitmap=00000000,00000000,00fff000,000003ff CPUs=0-9,44-55
function bitmap_to_rangelist() { # bitmap
  local hexmap="${1:-}" # must be non empty, only hex digits and commas

  [[ "${hexmap}" =~ ^[0-9a-fA-F,]+$ ]] || return 1
  hexmap="${hexmap//,/}"                # drop comma separators

  local out="" sep=""
  local pos=0 lo=-1 hi=0 # current bit position and open-range boundaries
  local idx nib b
  # Walk the hex digits right-to-left so bit positions ascend.
  for ((idx = ${#hexmap} - 1; idx >= 0; idx--)); do
    nib="$(( 16#${hexmap:${idx}:1} ))"
    for ((b = 0; b < 4; b++, pos++, nib >>= 1)); do
      (( nib & 1 )) || continue        # bit not set
      if (( lo < 0 )); then
        (( lo = pos, hi = pos ))       # first set bit opens a range
      elif (( pos == hi + 1 )); then
        (( hi = pos ))                 # contiguous: extend the open range
      else
        # Gap: flush the open range, then start a new one.
        out+="${sep}${lo}"
        (( lo != hi )) && out+="-${hi}"
        (( lo = pos, hi = pos ))
        sep=","
      fi
    done
  done
  (( lo < 0 )) && return               # no bits set: print nothing
  out+="${sep}${lo}"
  (( lo != hi )) && out+="-${hi}"      # flush the final range
  echo "$out"
}

# Converts a list of CPUs to a hexadecimal bitmap
# ex. CPUs=[0,1,2,3,4,5,6,7,8,9,44,45,46,47,48,49,50,51,52,53,54,55]
# bitmap=00000000,00000000,00fff000,000003ff
function rangelist_to_bitmap() { # list highest_cpu
  local list="${1:-}"   # may be empty; digits, commas and dashes only
  local top="${2:-1}"   # floor for the bitmap width (e.g. nproc)
  [[ "${list}" =~ ^[0-9,-]*$ ]] || return 1
  list="${list//,/ }"   # commas become word separators

  local -a nibbles=()   # nibbles[k] accumulates bits for CPUs 4k..4k+3
  local item lo hi cpu
  for item in $list; do
    read lo hi <<< "${item/-/ }"
    [[ -z "$hi" ]] && hi="$lo"    # single CPU: range of one
    for ((cpu = lo; cpu <= hi; cpu++)); do
      (( nibbles[cpu / 4] |= 1 << (cpu % 4) ))
    done
    (( top = hi > top ? hi : top ))
  done

  # Emit nibbles from most- to least-significant, inserting a comma every
  # 8 hex digits (32 CPUs per group); width is rounded up to 32 CPUs.
  local out="" hexchars="0123456789abcdef"
  local width
  (( width = (top + 31) / 32 * 8 ))
  for ((cpu = width - 1; cpu >= 0; cpu--)); do
    out+="${hexchars:${nibbles[$cpu]:-0}:1}"
    if (( cpu % 8 == 0 && cpu > 0 )); then
      out+=","
    fi
  done
  echo "$out"
}

# Returns all the network interface names excluding "lo"
get_network_interfaces() {
    local -a found=()
    local entry name

    for entry in ${ROOT_DIR}sys/class/net/*; do
        name="${entry##*/}"

        # Skip loopback and entries without a backing device link.
        [[ "${name}" == "lo" ]] && continue
        [[ -e "${entry}/device" ]] || continue

        found+=("${name}")
    done

    echo "${found[@]}"
}

# For XPS affinity configuration, we'd do the following assignment:
# 1. For each interface, divide the queues into two halves
# 2. Evenly distribute the vCPUs on NUMA0 to the first half of the queues
# 3. Evenly distribute the vCPUs on NUMA1 to the second half of the queues
# This function will have to be called once for each NUMA.
#
# Arguments:
#   $1   - NUMA node index being configured
#   $2.. - CPU ids belonging to that NUMA node
# Reads globals: machine_type, num_numa_nodes, ROOT_DIR, A4X_ALL_CPUS_MASK.
# NOTE(review): total_vcpus, nics_string and the per-loop variables below
# are not 'local' and leak into the global scope — confirm nothing depends
# on that before tightening.
function set_xps_affinity() {
  local numa="$1"
  local cpus=("${@:2}")

  total_vcpus=${#cpus[@]}
  nics_string=$(get_network_interfaces)

  IFS=' ' read -r -a nics <<< "$nics_string"
  for nic in "${nics[@]}"; do
    # Count the tx-* queue directories exposed for this NIC.
    tx_queue_count=$(ls -1 ${ROOT_DIR}sys/class/net/"$nic"/queues/ | grep tx | wc -l)

    if [[ "$machine_type" == *"a4x-"* ]]; then
      # All queues on a4x get the full mask.
      for (( queue=0; queue<tx_queue_count; queue++ )); do
        echo "${A4X_ALL_CPUS_MASK}" > "${ROOT_DIR}sys/class/net/$nic/queues/tx-$queue/xps_cpus"
      done
      continue
    fi

    if [[ $num_numa_nodes -le $tx_queue_count ]]; then
      # the number of queues to assign CPUs for this NUMA node.
      queues_per_numa=$(( tx_queue_count / num_numa_nodes ))

      # the number of CPUs to assign per queue
      cpus_per_queue=$(( total_vcpus / queues_per_numa))

      echo "nic=$nic tx_queue_count=$tx_queue_count queues_per_numa=$queues_per_numa cpus_per_queue=$cpus_per_queue"

      cpu_index=0
      # This NUMA node owns a contiguous slice of queues starting here.
      queue_offset=$(( queues_per_numa*numa ))
      for (( queue=queue_offset; queue<queue_offset+queues_per_numa; queue+=1 )); do
        xps_path=${ROOT_DIR}sys/class/net/$nic/queues/tx-$queue/xps_cpus
        xps_cpus=""

        # Assign all the remaining CPUs to the last queue
        if [[ queue -eq $(( queue_offset + queues_per_numa - 1 )) ]]; then
          cpus_per_queue=$(( total_vcpus - cpu_index ))
        fi

        # Build a comma-separated CPU list for this queue.
        for (( i=0; i<cpus_per_queue; i+=1 )); do
          xps_cpus+="${cpus[cpu_index]},"
          cpu_index=$(( cpu_index + 1 ))
        done

        # remove the last ","
        xps_cpus="${xps_cpus%,}"
        # Convert the CPU list into the hex bitmap format xps_cpus expects.
        cpu_mask=$(rangelist_to_bitmap "$xps_cpus" "$(nproc)")
        echo ${cpu_mask} > $xps_path
      done
    else
      # num_numa_nodes > tx_queue_count.
      # multiple NUMA nodes share a queue. We append to the mask.
      queue=$(( numa % tx_queue_count ))
      xps_path=${ROOT_DIR}sys/class/net/$nic/queues/tx-$queue/xps_cpus

      # Merge this node's CPUs with whatever mask was already installed.
      current_mask=$(cat "$xps_path" 2>/dev/null || echo "0")
      current_rangelist=$(bitmap_to_rangelist "$current_mask" 2>/dev/null || echo "")

      # Flatten cpus array to comma-separated list
      new_cpus_list=$(IFS=,; echo "${cpus[*]}")
      ranges="${current_rangelist},${new_cpus_list}"
      ranges="${ranges#,}"

      cpu_mask=$(rangelist_to_bitmap "$ranges" "$(nproc)")
      echo ${cpu_mask} > $xps_path
    fi
  done
}

echo "Running $(basename $0)."
machine_type=$(get_metadata "instance/machine-type")
echo "Machine type: $machine_type"

VIRTIO_NET_DEVS=${ROOT_DIR}sys/bus/virtio/drivers/virtio_net/virtio*
is_multinic_accelerator_platform
IS_MULTINIC_ACCELERATOR_PLATFORM=$?

# Loop through all the virtionet devices and enable multi-queue
if [ -x "$(command -v ethtool)" ]; then
  for dev in $VIRTIO_NET_DEVS; do
    ETH_DEVS=${dev}/net/*
    for eth_dev in $ETH_DEVS; do
      eth_dev=$(basename "$eth_dev")
      if ! errormsg=$(ethtool -l "$eth_dev" 2>&1); then
        echo "ethtool says that $eth_dev does not support virtionet multiqueue: $errormsg."
        continue
      fi
      num_max_channels=$(ethtool -l "$eth_dev" | grep -m 1 Combined | cut -f2)
      # Skip when the channel count is unavailable or the device only
      # supports one combined channel. (This previously tested '-n', which
      # inverted the check: every valid multi-queue device was skipped and
      # only devices with an empty count fell through to set_channels.)
      if [[ -z "${num_max_channels}" || "${num_max_channels}" -eq "1" ]]; then
        echo "num_max_channels is n/a, skipping set channels for $eth_dev"
        continue
      fi
      if is_decimal_int "$num_max_channels" && \
        set_channels "$eth_dev" "$num_max_channels"; then
        echo "Set channels for $eth_dev to $num_max_channels."
      else
        echo "Could not set channels for $eth_dev to $num_max_channels."
      fi
    done
  done
else
  echo "ethtool not found: cannot configure virtionet multiqueue."
fi

# Bind virtionet IRQs: legacy intx IRQs all go to CPU 0; per-queue MSI-X
# IRQs are pinned to the CPU with the same index as their queue.
for dev in $VIRTIO_NET_DEVS
do
    dev=$(basename "$dev")
    irq_dir=${ROOT_DIR}proc/irq/*
    for irq in $irq_dir
    do
      smp_affinity="${irq}/smp_affinity_list"
      # Skip IRQs whose affinity is not exposed via this file.
      [ ! -f "${smp_affinity}" ] && continue
      # Classify this IRQ as virtionet intx, virtionet MSI-X, or non-virtionet
      # If the IRQ type is virtionet intx, a subdirectory with the same name as
      # the device will be present. If the IRQ type is virtionet MSI-X, then
      # a subdirectory of the form <device name>-<input|output>.N will exist.
      # In this case, N is the input (output) queue number, and is specified as
      # a decimal integer ranging from 0 to K - 1 where K is the number of
      # input (output) queues in the virtionet device.
      virtionet_intx_dir="${irq}/${dev}"
      virtionet_msix_dir_regex=".*/${dev}-(input|output)\.([0-9]+)$"
      if [ -d "${virtionet_intx_dir}" ]; then
        # All virtionet intx IRQs are delivered to CPU 0
        echo "Setting ${smp_affinity} to 01 for device ${dev}."
        echo "01" > "${smp_affinity}"
        continue
      fi
      # Not virtionet intx, probe for MSI-X
      # NOTE(review): if several entries matched, the last match's queue
      # number would win; each IRQ directory is expected to contain at most
      # one — confirm.
      virtionet_msix_found=0
      for entry in ${irq}/${dev}*; do
        if [[ "$entry" =~ ${virtionet_msix_dir_regex} ]]; then
          virtionet_msix_found=1
          queue_num=${BASH_REMATCH[2]}
        fi
      done
      affinity_hint="${irq}/affinity_hint"
      [ "$virtionet_msix_found" -eq 0 -o ! -f "${affinity_hint}" ] && continue

      # Set the IRQ CPU affinity to the virtionet-initialized affinity hint
      # smp_affinity_list accepts a CPU list, so writing the queue number
      # pins queue N's interrupt to CPU N.
      echo "Setting ${smp_affinity} to ${queue_num} for device ${dev}."
      echo "${queue_num}" > "${smp_affinity}"
      real_affinity=`cat ${smp_affinity}`
      echo "${smp_affinity}: real affinity ${real_affinity}"
    done
done

# Set smp_affinity properly for gvnic queues. '-ntfy-block.' is unique to gve
# and will not affect virtio queues.
for i in ${ROOT_DIR}proc/irq/*; do
  if ls ${i}/*-ntfy-block.* 1> /dev/null 2>&1; then
    if [ -f ${i}/affinity_hint ]; then
      # affinity_hint and smp_affinity share the same bitmask file format,
      # so the hint can be copied over verbatim — replicating what
      # irqbalance would do if it were enabled.
      echo Setting smp_affinity on ${i} to $(cat ${i}/affinity_hint)
      cp ${i}/affinity_hint ${i}/smp_affinity
    fi
  fi
done

# Everything below only applies to multi-NIC accelerator platforms.
if [[ "$IS_MULTINIC_ACCELERATOR_PLATFORM" != 0 ]]; then
  exit
fi

num_numa_nodes=$(ls -d ${ROOT_DIR}sys/devices/system/node/node* | wc -l)
echo "Found ${num_numa_nodes} NUMA nodes."

# Configure XPS once per NUMA node: expand the node's cpulist ranges into
# individual CPU ids, then spread them over that node's share of TX queues.
for ((node=0; node<num_numa_nodes; node++)); do
  ranges=$(get_vcpu_ranges "$node")
  dec_ranges=()
  unpack_cpu_ranges "${ranges}" dec_ranges

  echo -e "\nConfiguring XPS affinity for devices on NUMA ${node}"
  echo -e "vCPUs on NUMA${node} [${dec_ranges[@]}]"
  set_xps_affinity "$node" "${dec_ranges[@]}"
done

# Assign IRQ binding for network interfaces based on pci bus ordering.
# Avoid setting binding IRQ on vCPU 0 as it is a busy vCPU being heavily
# used by the system.
for ((node=0; node<num_numa_nodes; node++)); do
  ranges=$(get_vcpu_ranges "$node")

  declare -a node_irq_ranges=()
  unpack_cpu_ranges "${ranges}" node_irq_ranges

  if [[ $node -eq 0 ]]; then
    # Skip vCPU 0
    node_irq_ranges=("${node_irq_ranges[@]:1}")
  fi

  echo -e "\nSetting IRQ affinity with vCPUs on NUMA${node} [${node_irq_ranges[@]}]"
  bind_cores_index=0
  # Resolve each net-device symlink to its PCI path and sort, so NICs are
  # visited in a deterministic PCI-bus order. NOTE(review): the while loop
  # runs in a pipeline subshell, so bind_cores_index accumulates across the
  # NICs of one node but any final value is discarded when the pipeline
  # ends (it is re-initialized to 0 for the next node above anyway).
  find ${ROOT_DIR}sys/class/net -type l | xargs -L 1 realpath | grep '/sys/devices/pci' | sort | xargs -L 1 basename | while read nic_name; do
    nic_numa_node=$(cat ${ROOT_DIR}sys/class/net/"$nic_name"/device/numa_node)
    # Only bind NICs attached to the NUMA node currently being processed.
    if [[ $nic_numa_node -ne $node ]]; then
      continue
    fi

    # For non-gvnic/idpf devices (e.g. mlx5), the IRQ bindings will be handled by the device's driver.
    if is_gvnic "$nic_name"; then
      bind_cores_index=$(set_irq_range_gve "$nic_name" "$bind_cores_index" "${node_irq_ranges[@]}")
    elif is_idpf "$nic_name"; then
      bind_cores_index=$(set_irq_range_idpf "$nic_name" "$bind_cores_index" "${node_irq_ranges[@]}")
      # a4x-maxgpu machines additionally get tuned NIC ring-buffer sizes.
      if [[ $machine_type == *"a4x-maxgpu-"* ]]; then
        ethtool -G "$nic_name" rx "$A4X_RX_RING_LENGTH" tx "$A4X_TX_RING_LENGTH"
      fi
    else
      echo "$nic_name is not a gvnic/idpf device, not setting irq affinity on this device"
    fi
  done
done
