#!/bin/sh

# Locate the NVIDIA CUDA compiler driver and store its path in the global CC.
#
# Search order (first hit wins):
#   1. `nvcc` as resolved by `command -v`, i.e. through $PATH
#   2. $CUDA_PATH/bin/nvcc   (only when CUDA_PATH is set and non-empty)
#   3. /opt/cuda/bin/nvcc
#   4. /usr/local/cuda/bin/nvcc
#
# Globals:  CUDA_PATH (read, optional), CC (written)
# Returns:  0 with CC holding the path that was found; 1 if nothing was
#           found, in which case CC holds the last path probed and must not
#           be used.
find_nvcc() {
  CC=$(command -v nvcc 2>/dev/null) && return

  # ${CUDA_PATH:+...} removes the candidate altogether when CUDA_PATH is unset
  # or empty, so a bogus "/bin/nvcc" is never probed and the function keeps
  # working when the caller runs under `set -u`.
  for CC in ${CUDA_PATH:+"$CUDA_PATH/bin/nvcc"} \
    /opt/cuda/bin/nvcc \
    /usr/local/cuda/bin/nvcc; do
    [ -x "$CC" ] && return
  done
  return 1
}

# Locate the AMD HIP compiler driver and store its path in the global CC.
#
# Search order (first hit wins):
#   1. `hipcc` as resolved by `command -v`, i.e. through $PATH
#   2. $HIP_PATH/bin/hipcc   (only when HIP_PATH is set and non-empty)
#   3. /opt/rocm/bin/hipcc
#   4. /usr/local/rocm/bin/hipcc
#
# Globals:  HIP_PATH (read, optional), CC (written)
# Returns:  0 with CC holding the path that was found; 1 if nothing was
#           found, in which case CC holds the last path probed and must not
#           be used.
find_hipcc() {
  CC=$(command -v hipcc 2>/dev/null) && return

  # ${HIP_PATH:+...} removes the candidate altogether when HIP_PATH is unset
  # or empty, so a bogus "/bin/hipcc" is never probed and the function keeps
  # working when the caller runs under `set -u`.
  for CC in ${HIP_PATH:+"$HIP_PATH/bin/hipcc"} \
    /opt/rocm/bin/hipcc \
    /usr/local/rocm/bin/hipcc; do
    [ -x "$CC" ] && return
  done
  return 1
}

# Pick the GPU toolchain. hipcc is tried first, nvcc second; whichever lookup
# succeeds has already stored the compiler path in CC.
#   VENDOR - AMD or NVIDIA, used later to decide how arguments are rewritten
#   FLAGS  - extra options placed before the user's arguments
VENDOR=
FLAGS=
if find_hipcc; then
  VENDOR=AMD
elif find_nvcc; then
  VENDOR=NVIDIA
  # nvcc option that passes options it does not recognise on to the host
  # compiler.
  FLAGS="--forward-unknown-to-host-compiler"
fi

# Neither compiler driver was found: report on stderr and give up.
if [ -z "$VENDOR" ]; then
  echo 'error: need either hipcc (AMD) or nvcc (NVIDIA) on $PATH' >&2
  exit 1
fi

# Rewrite the command line for the selected toolchain, then hand over to it.
#
# The positional parameters are rotated in place: each original argument is
# taken off the front and its translation (zero, one or two words) is appended
# to the back, so after the initial $# rounds "$@" holds the rewritten list in
# the original order.
pending=$#
while [ "$pending" -gt 0 ]; do
  arg=$1
  shift
  pending=$((pending - 1))
  case "$VENDOR:$arg" in
    AMD:-lcublas)
      # A cuBLAS link request becomes hipBLAS plus rocBLAS on AMD.
      set -- "$@" -lhipblas -lrocblas
      ;;
    AMD:--use_fast_math)
      # Dropped on AMD (presumably hipcc does not accept this nvcc flag).
      ;;
    *)
      # Everything else is passed through untouched.
      set -- "$@" "$arg"
      ;;
  esac
done

# FLAGS is left unquoted so that an empty value expands to no argument at all
# rather than to an empty string.
exec "$CC" $FLAGS "$@"
