Files

333 lines
6.2 KiB
Go

package ping
import (
"context"
"errors"
"fmt"
"net"
"os"
"runtime"
"strings"
"sync"
"syscall"
"time"
"git.ktf.ninja/tabledevil/gists/projects/go-tools/go/goipgrep/internal/resolve"
)
// Options configures FilterIPs and FilterMACs.
type Options struct {
// Mode names the probe to run: "icmp", "tcp", or "auto"/"" to use ICMP when
// a raw ICMP socket can be opened and TCP otherwise. Matching is
// case-insensitive and ignores surrounding whitespace.
Mode string
// Timeout bounds each individual probe (each TCP dial, or the wait for an
// ICMP echo reply).
Timeout time.Duration
// Jobs is the maximum number of concurrent probes; values <= 0 mean
// runtime.GOMAXPROCS(0).
Jobs int
}
// defaultProbeTimeout is the per-probe timeout used when Options.Timeout is
// zero or negative. A zero timeout would otherwise make ICMP probes fail
// immediately (their deadline would already have passed) and leave TCP dials
// without any timeout of their own.
const defaultProbeTimeout = time.Second

// FilterIPs probes every address in ips concurrently and returns the ones
// that answered, preserving their order in ips.
//
// opts.Mode selects the probe ("icmp", "tcp", or "auto"/"" to use ICMP when a
// raw ICMP socket can be opened and TCP otherwise); any mode-selection warning
// is printed to stderr. opts.Jobs bounds the number of concurrent probes
// (GOMAXPROCS when <= 0) and opts.Timeout bounds each probe
// (defaultProbeTimeout when <= 0). If ctx is cancelled before every address
// has been probed, FilterIPs returns ctx.Err() instead of partial results.
func FilterIPs(ctx context.Context, ips []string, opts Options) ([]string, error) {
	if len(ips) == 0 {
		return nil, nil
	}
	mode, timeout, jobs := probeSettings(opts, len(ips))
	return filterConcurrent(ctx, ips, jobs, func(ip string) bool {
		return reachable(ctx, ip, mode, timeout)
	})
}

// FilterMACs resolves each MAC to one or more local-neighbor IPs then applies
// IP reachability: a MAC is kept when at least one of its IPs answers. MACs
// without a neighbor-table entry are dropped. Output preserves the order of
// macs; options and cancellation behave as for FilterIPs.
// On non-Linux, neighbor resolution may be unsupported without external tools.
func FilterMACs(ctx context.Context, macs []string, opts Options) ([]string, error) {
	if len(macs) == 0 {
		return nil, nil
	}
	tab, err := resolve.LoadNeighborTable()
	if err != nil {
		return nil, err
	}
	mode, timeout, jobs := probeSettings(opts, len(macs))
	return filterConcurrent(ctx, macs, jobs, func(mac string) bool {
		for _, ip := range tab.ByMAC[mac] {
			if reachable(ctx, ip, mode, timeout) {
				return true
			}
		}
		return false
	})
}

// probeSettings resolves the effective probe mode, per-probe timeout and
// worker count for n items, printing any mode-selection warning to stderr.
func probeSettings(opts Options, n int) (mode string, timeout time.Duration, jobs int) {
	mode, warn := selectMode(opts.Mode)
	if warn != "" {
		fmt.Fprintln(os.Stderr, "ipgrep:", warn)
	}
	timeout = opts.Timeout
	if timeout <= 0 {
		timeout = defaultProbeTimeout
	}
	jobs = opts.Jobs
	if jobs <= 0 {
		jobs = runtime.GOMAXPROCS(0)
	}
	// No point starting workers that would never receive an item.
	if jobs > n {
		jobs = n
	}
	return mode, timeout, jobs
}

// filterConcurrent evaluates keep for every item using up to jobs goroutines
// and returns the items for which keep reported true, in input order. All
// workers have exited before it returns. If ctx is cancelled before every
// item has been evaluated it returns ctx.Err().
func filterConcurrent(ctx context.Context, items []string, jobs int, keep func(string) bool) ([]string, error) {
	if len(items) == 0 {
		return nil, nil
	}
	if jobs < 1 {
		jobs = 1
	}
	// hits[i] records keep(items[i]); each index is sent to exactly one
	// worker, so the writes never race.
	hits := make([]bool, len(items))
	idx := make(chan int)
	var wg sync.WaitGroup
	wg.Add(jobs)
	for w := 0; w < jobs; w++ {
		go func() {
			defer wg.Done()
			for i := range idx {
				hits[i] = keep(items[i])
			}
		}()
	}
feed:
	for i := range items {
		select {
		case idx <- i:
		case <-ctx.Done():
			break feed
		}
	}
	close(idx)
	wg.Wait()

	// Probes run under a cancelled ctx report "unreachable", so results
	// gathered after cancellation cannot be trusted.
	if err := ctx.Err(); err != nil {
		return nil, err
	}
	var res []string
	for i, ok := range hits {
		if ok {
			res = append(res, items[i])
		}
	}
	return res, nil
}
// selectMode normalizes a user-supplied ping mode and resolves it to the
// concrete probe to run: "icmp" or "tcp".
//
// Matching is case-insensitive and ignores surrounding whitespace. "auto" and
// "" choose ICMP when canICMP reports a usable raw socket and TCP otherwise.
// The second return value is a warning for the user, empty when there is
// nothing to report. It is set when auto mode falls back to TCP and when mode
// is unrecognised. Unrecognised modes use the TCP probe, and the warning
// keeps a typo such as "icpm" from being ignored silently.
func selectMode(mode string) (string, string) {
	switch m := strings.ToLower(strings.TrimSpace(mode)); m {
	case "icmp", "tcp":
		return m, ""
	case "auto", "":
		if canICMP() {
			return "icmp", ""
		}
		return "tcp", "ICMP not permitted/available; falling back to TCP probe (use --ping-mode=tcp to silence)"
	default:
		return "tcp", fmt.Sprintf("unknown ping mode %q; using TCP probe", mode)
	}
}
// canICMP reports whether this process can open a raw IPv4 ICMP socket, the
// same kind icmpEcho uses. The probe socket is closed immediately. Any
// failure, permission errors being the common one, is reported as
// "not available".
func canICMP() bool {
	conn, err := net.ListenPacket("ip4:icmp", "0.0.0.0")
	if err != nil {
		return false
	}
	_ = conn.Close()
	return true
}
func reachable(ctx context.Context, ip string, mode string, timeout time.Duration) bool {
switch mode {
case "icmp":
ok, err := icmpEcho(ctx, ip, timeout)
if err == nil {
return ok
}
// ICMP implementation is IPv4-only; use TCP probe for non-IPv4.
if err != nil && strings.Contains(err.Error(), "not an IPv4") {
return tcpProbe(ctx, ip, timeout)
}
if isPerm(err) {
// If user forced icmp, don't silently fall back.
return false
}
return false
case "tcp":
return tcpProbe(ctx, ip, timeout)
default:
return tcpProbe(ctx, ip, timeout)
}
}
// icmpEcho sends one ICMP echo request to the IPv4 address dst and waits for
// the matching echo reply.
//
// It opens its own raw "ip4:icmp" socket. It returns (true, nil) once an echo
// reply arrives that was sent by dst and carries this request's identifier
// and sequence number. It returns false with an error in these cases:
//   - dst is not an IPv4 address;
//   - the socket cannot be opened;
//   - the send fails;
//   - no matching reply arrives before timeout (or before ctx's deadline, if
//     that is earlier);
//   - ctx is cancelled.
func icmpEcho(ctx context.Context, dst string, timeout time.Duration) (bool, error) {
	ip := net.ParseIP(dst)
	if ip == nil || ip.To4() == nil {
		return false, errors.New("not an IPv4 address")
	}
	if err := ctx.Err(); err != nil {
		return false, err
	}
	c, err := net.ListenPacket("ip4:icmp", "0.0.0.0")
	if err != nil {
		return false, err
	}
	defer c.Close()

	// Echo request: type 8, code 0, identifier, sequence, 16-byte payload.
	// The checksum at [2:4] is computed last over the zeroed field.
	id := os.Getpid() & 0xffff
	seq := int(time.Now().UnixNano() & 0xffff)
	msg := make([]byte, 8+16)
	msg[0] = 8 // type
	msg[1] = 0 // code
	msg[4] = byte(id >> 8)
	msg[5] = byte(id)
	msg[6] = byte(seq >> 8)
	msg[7] = byte(seq)
	copy(msg[8:], "goipgrep-icmp\000")
	cs := checksum(msg)
	msg[2] = byte(cs >> 8)
	msg[3] = byte(cs)

	// Never wait past ctx's own deadline.
	deadline := time.Now().Add(timeout)
	if d, ok := ctx.Deadline(); ok && d.Before(deadline) {
		deadline = d
	}
	if err := c.SetDeadline(deadline); err != nil {
		return false, err
	}

	// Unblock ReadFrom promptly if ctx is cancelled while we wait; the
	// goroutine exits when this function returns.
	stop := make(chan struct{})
	defer close(stop)
	go func() {
		select {
		case <-ctx.Done():
			_ = c.SetReadDeadline(time.Now())
		case <-stop:
		}
	}()

	if _, err := c.WriteTo(msg, &net.IPAddr{IP: ip}); err != nil {
		return false, err
	}
	buf := make([]byte, 1500)
	for {
		n, peer, err := c.ReadFrom(buf)
		if err != nil {
			if ctxErr := ctx.Err(); ctxErr != nil {
				return false, ctxErr
			}
			return false, err
		}
		// Concurrent probes all use the PID as identifier, and a raw ICMP
		// socket may see replies meant for them, so only accept replies
		// sent by dst itself.
		if from, ok := peer.(*net.IPAddr); !ok || !from.IP.Equal(ip) {
			continue
		}
		b := buf[:n]
		// Some OSes may include IPv4 header; strip if present.
		if len(b) >= 20 && (b[0]>>4) == 4 {
			ihl := int(b[0]&0x0f) * 4
			if ihl <= len(b) {
				b = b[ihl:]
			}
		}
		if len(b) < 8 {
			continue
		}
		if b[0] != 0 || b[1] != 0 { // not an echo reply (type 0, code 0)
			continue
		}
		rid := int(b[4])<<8 | int(b[5])
		rseq := int(b[6])<<8 | int(b[7])
		if rid == id && rseq == seq {
			return true, nil
		}
	}
}
// checksum computes the Internet checksum (RFC 1071) of b. The bytes are
// summed as big-endian 16-bit words, with an odd trailing byte treated as
// the high byte of a final word. Carries are folded back into the low 16
// bits, and the one's complement of the result is returned.
func checksum(b []byte) uint16 {
	var acc uint32
	for len(b) >= 2 {
		acc += uint32(b[0])<<8 | uint32(b[1])
		b = b[2:]
	}
	if len(b) == 1 {
		acc += uint32(b[0]) << 8
	}
	// Fold carries back in until the sum fits in 16 bits.
	for acc > 0xffff {
		acc = acc>>16 + acc&0xffff
	}
	return ^uint16(acc)
}
func tcpProbe(ctx context.Context, ip string, timeout time.Duration) bool {
ports := []int{443, 80}
d := net.Dialer{Timeout: timeout}
for _, p := range ports {
addr := net.JoinHostPort(ip, fmt.Sprintf("%d", p))
c, err := d.DialContext(ctx, "tcp", addr)
if err == nil {
_ = c.Close()
return true
}
if isConnRefused(err) {
return true
}
}
return false
}
// isPerm reports whether err stems from a permission failure. That means
// anything matching os.ErrPermission, or a *net.OpError whose cause is
// EPERM or EACCES. A nil error is never a permission failure.
func isPerm(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, os.ErrPermission):
		return true
	}
	var opErr *net.OpError
	if !errors.As(err, &opErr) {
		return false
	}
	return errors.Is(opErr.Err, syscall.EPERM) || errors.Is(opErr.Err, syscall.EACCES)
}
// isConnRefused reports whether err is, or wraps, a *net.OpError whose
// underlying cause is ECONNREFUSED. Some platforms put an *os.SyscallError
// between the two; errors.Is unwraps it.
func isConnRefused(err error) bool {
	var opErr *net.OpError
	return errors.As(err, &opErr) && errors.Is(opErr.Err, syscall.ECONNREFUSED)
}