333 lines
6.2 KiB
Go
333 lines
6.2 KiB
Go
package ping
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
|
|
"git.ktf.ninja/tabledevil/gists/projects/go-tools/go/goipgrep/internal/resolve"
|
|
)
|
|
|
|
// Options configures FilterIPs and FilterMACs.
type Options struct {
	// Mode selects the probe method. selectMode trims and lower-cases it;
	// "icmp" and "tcp" force that method, while "auto" or "" use ICMP when a
	// raw ICMP socket can be opened and fall back to TCP otherwise.
	Mode string
	// Timeout bounds each individual probe: it is the dial timeout for every
	// TCP port tried and the wait for an ICMP echo reply.
	Timeout time.Duration
	// Jobs is the number of targets probed concurrently; values <= 0 mean
	// runtime.GOMAXPROCS(0).
	Jobs int
}
|
|
|
|
func FilterIPs(ctx context.Context, ips []string, opts Options) ([]string, error) {
|
|
if len(ips) == 0 {
|
|
return nil, nil
|
|
}
|
|
mode, warn := selectMode(opts.Mode)
|
|
if warn != "" {
|
|
fmt.Fprintln(os.Stderr, "ipgrep:", warn)
|
|
}
|
|
|
|
j := opts.Jobs
|
|
if j <= 0 {
|
|
j = runtime.GOMAXPROCS(0)
|
|
}
|
|
|
|
in := make(chan string)
|
|
out := make(chan string)
|
|
var wg sync.WaitGroup
|
|
|
|
worker := func() {
|
|
defer wg.Done()
|
|
for ip := range in {
|
|
if reachable(ctx, ip, mode, opts.Timeout) {
|
|
select {
|
|
case out <- ip:
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
wg.Add(j)
|
|
for i := 0; i < j; i++ {
|
|
go worker()
|
|
}
|
|
|
|
go func() {
|
|
defer close(in)
|
|
for _, ip := range ips {
|
|
select {
|
|
case in <- ip:
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
go func() {
|
|
wg.Wait()
|
|
close(out)
|
|
}()
|
|
|
|
var res []string
|
|
for ip := range out {
|
|
res = append(res, ip)
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
// FilterMACs resolves each MAC to one or more local-neighbor IPs then applies IP reachability.
|
|
// On non-Linux, neighbor resolution may be unsupported without external tools.
|
|
func FilterMACs(ctx context.Context, macs []string, opts Options) ([]string, error) {
|
|
tab, err := resolve.LoadNeighborTable()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
mode, warn := selectMode(opts.Mode)
|
|
if warn != "" {
|
|
fmt.Fprintln(os.Stderr, "ipgrep:", warn)
|
|
}
|
|
|
|
j := opts.Jobs
|
|
if j <= 0 {
|
|
j = runtime.GOMAXPROCS(0)
|
|
}
|
|
|
|
in := make(chan string)
|
|
out := make(chan string)
|
|
var wg sync.WaitGroup
|
|
|
|
worker := func() {
|
|
defer wg.Done()
|
|
for mac := range in {
|
|
ips := tab.ByMAC[mac]
|
|
ok := false
|
|
for _, ip := range ips {
|
|
if reachable(ctx, ip, mode, opts.Timeout) {
|
|
ok = true
|
|
break
|
|
}
|
|
}
|
|
if ok {
|
|
select {
|
|
case out <- mac:
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
wg.Add(j)
|
|
for i := 0; i < j; i++ {
|
|
go worker()
|
|
}
|
|
|
|
go func() {
|
|
defer close(in)
|
|
for _, mac := range macs {
|
|
select {
|
|
case in <- mac:
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
go func() {
|
|
wg.Wait()
|
|
close(out)
|
|
}()
|
|
|
|
var res []string
|
|
for mac := range out {
|
|
res = append(res, mac)
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
func selectMode(mode string) (string, string) {
|
|
mode = strings.ToLower(strings.TrimSpace(mode))
|
|
switch mode {
|
|
case "icmp":
|
|
return "icmp", ""
|
|
case "tcp":
|
|
return "tcp", ""
|
|
case "auto", "":
|
|
if canICMP() {
|
|
return "icmp", ""
|
|
}
|
|
return "tcp", "ICMP not permitted/available; falling back to TCP probe (use --ping-mode=tcp to silence)"
|
|
default:
|
|
return "auto", ""
|
|
}
|
|
}
|
|
|
|
// canICMP reports whether this process can open the raw IPv4 ICMP socket that
// icmpEcho relies on. The trial socket is closed right away. Every failure,
// most commonly a permission error, is reported as false.
func canICMP() bool {
	conn, err := net.ListenPacket("ip4:icmp", "0.0.0.0")
	if err != nil {
		// Not distinguishing permission errors from other failures: either
		// way ICMP probing is unavailable.
		return false
	}
	_ = conn.Close()
	return true
}
|
|
|
|
func reachable(ctx context.Context, ip string, mode string, timeout time.Duration) bool {
|
|
switch mode {
|
|
case "icmp":
|
|
ok, err := icmpEcho(ctx, ip, timeout)
|
|
if err == nil {
|
|
return ok
|
|
}
|
|
// ICMP implementation is IPv4-only; use TCP probe for non-IPv4.
|
|
if err != nil && strings.Contains(err.Error(), "not an IPv4") {
|
|
return tcpProbe(ctx, ip, timeout)
|
|
}
|
|
if isPerm(err) {
|
|
// If user forced icmp, don't silently fall back.
|
|
return false
|
|
}
|
|
return false
|
|
case "tcp":
|
|
return tcpProbe(ctx, ip, timeout)
|
|
default:
|
|
return tcpProbe(ctx, ip, timeout)
|
|
}
|
|
}
|
|
|
|
func icmpEcho(ctx context.Context, dst string, timeout time.Duration) (bool, error) {
|
|
ip := net.ParseIP(dst)
|
|
if ip == nil || ip.To4() == nil {
|
|
return false, errors.New("not an IPv4 address")
|
|
}
|
|
|
|
c, err := net.ListenPacket("ip4:icmp", "0.0.0.0")
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
defer c.Close()
|
|
|
|
// Echo request: type 8, code 0.
|
|
id := os.Getpid() & 0xffff
|
|
seq := int(time.Now().UnixNano() & 0xffff)
|
|
msg := make([]byte, 8+16)
|
|
msg[0] = 8 // type
|
|
msg[1] = 0 // code
|
|
// checksum at [2:4]
|
|
msg[4] = byte(id >> 8)
|
|
msg[5] = byte(id)
|
|
msg[6] = byte(seq >> 8)
|
|
msg[7] = byte(seq)
|
|
copy(msg[8:], []byte("goipgrep-icmp\000"))
|
|
cs := checksum(msg)
|
|
msg[2] = byte(cs >> 8)
|
|
msg[3] = byte(cs)
|
|
|
|
deadline := time.Now().Add(timeout)
|
|
_ = c.SetDeadline(deadline)
|
|
|
|
dstAddr := &net.IPAddr{IP: ip}
|
|
if _, err := c.WriteTo(msg, dstAddr); err != nil {
|
|
return false, err
|
|
}
|
|
|
|
buf := make([]byte, 1500)
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return false, ctx.Err()
|
|
default:
|
|
}
|
|
n, _, err := c.ReadFrom(buf)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
b := buf[:n]
|
|
// Some OSes may include IPv4 header; strip if present.
|
|
if len(b) >= 20 && (b[0]>>4) == 4 {
|
|
ihl := int(b[0]&0x0f) * 4
|
|
if ihl <= len(b) {
|
|
b = b[ihl:]
|
|
}
|
|
}
|
|
if len(b) < 8 {
|
|
continue
|
|
}
|
|
typ := b[0]
|
|
code := b[1]
|
|
if typ != 0 || code != 0 { // echo reply
|
|
continue
|
|
}
|
|
rid := int(b[4])<<8 | int(b[5])
|
|
rseq := int(b[6])<<8 | int(b[7])
|
|
if rid == id && rseq == seq {
|
|
return true, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
// checksum computes the 16-bit one's-complement Internet checksum of b.
// Bytes are summed as big-endian 16-bit words. A trailing odd byte is treated
// as the high byte of a final zero-padded word. Carries are folded back into
// the low 16 bits and the result is complemented.
func checksum(b []byte) uint16 {
	var acc uint32
	rest := b
	for len(rest) >= 2 {
		acc += uint32(rest[0])<<8 | uint32(rest[1])
		rest = rest[2:]
	}
	if len(rest) == 1 {
		acc += uint32(rest[0]) << 8
	}
	for acc > 0xffff {
		acc = (acc >> 16) + (acc & 0xffff)
	}
	return ^uint16(acc)
}
|
|
|
|
func tcpProbe(ctx context.Context, ip string, timeout time.Duration) bool {
|
|
ports := []int{443, 80}
|
|
d := net.Dialer{Timeout: timeout}
|
|
for _, p := range ports {
|
|
addr := net.JoinHostPort(ip, fmt.Sprintf("%d", p))
|
|
c, err := d.DialContext(ctx, "tcp", addr)
|
|
if err == nil {
|
|
_ = c.Close()
|
|
return true
|
|
}
|
|
if isConnRefused(err) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// isPerm reports whether err represents a permission failure. That is either
// anything matching os.ErrPermission, or a *net.OpError whose inner error is
// EPERM or EACCES. A nil error is never a permission failure.
func isPerm(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, os.ErrPermission):
		return true
	}
	var opErr *net.OpError
	if !errors.As(err, &opErr) {
		return false
	}
	inner := opErr.Err
	return errors.Is(inner, syscall.EPERM) || errors.Is(inner, syscall.EACCES)
}
|
|
|
|
// isConnRefused reports whether err is a *net.OpError whose cause is
// ECONNREFUSED. If the OpError's inner error contains an *os.SyscallError,
// that syscall error's own cause is the one inspected. Errors that are not
// wrapped in a *net.OpError are never treated as refusals.
func isConnRefused(err error) bool {
	var opErr *net.OpError
	if !errors.As(err, &opErr) {
		return false
	}
	cause := opErr.Err
	// Some platforms put an *os.SyscallError between the OpError and the errno.
	var sysErr *os.SyscallError
	if errors.As(cause, &sysErr) {
		cause = sysErr.Err
	}
	return errors.Is(cause, syscall.ECONNREFUSED)
}
|