rclone/vendor/github.com/anacrolix/dms/ssdp/ssdp.go
2019-01-09 19:14:14 +00:00

331 lines
6.9 KiB
Go

package ssdp
import (
"bufio"
"bytes"
"fmt"
"io"
"log"
"math/rand"
"net"
"net/http"
"net/textproto"
"strconv"
"strings"
"time"
"golang.org/x/net/ipv4"
)
// SSDP protocol constants.
const (
	// AddrString is the multicast group address and port that SSDP
	// messages are sent to and received on.
	AddrString = "239.255.255.250:1900"

	// rootDevice is the target type advertised for the root device.
	rootDevice = "upnp:rootdevice"
	// aliveNTS is the NTS header value of an availability announcement.
	aliveNTS = "ssdp:alive"
	// byebyeNTS is the NTS header value of a departure announcement.
	byebyeNTS = "ssdp:byebye"
)

// NetAddr is AddrString resolved as a UDPv4 address. It is filled in by the
// package initializer, which panics if the address cannot be resolved.
var NetAddr *net.UDPAddr

func init() {
	resolved, resolveErr := net.ResolveUDPAddr("udp4", AddrString)
	if resolveErr != nil {
		log.Panicf("Could not resolve %s: %s", AddrString, resolveErr)
	}
	NetAddr = resolved
}
// badStringError reports a malformed part of an SSDP request together with
// the offending text.
type badStringError struct {
	what string // description of the problem
	str  string // the text that caused it
}

// Error renders the description followed by the quoted offending text.
func (e *badStringError) Error() string {
	return fmt.Sprintf("%s %q", e.what, e.str)
}

// ReadRequest parses an SSDP request from b. SSDP requests use HTTP message
// syntax but their request target must be "*". Only Method, ProtoMajor,
// ProtoMinor and Header are set on the returned request.
//
// An io.EOF hit while reading the request line is returned as-is; an io.EOF
// hit after the request line has been read is reported as
// io.ErrUnexpectedEOF. Malformed request lines yield a *badStringError.
func ReadRequest(b *bufio.Reader) (req *http.Request, err error) {
	reader := textproto.NewReader(b)
	requestLine, err := reader.ReadLine()
	if err != nil {
		return nil, err
	}
	// The request line was present, so running out of input from here on
	// means the message was truncated.
	defer func() {
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
	}()

	// SplitN keeps any further spaces inside the third (version) field.
	parts := strings.SplitN(requestLine, " ", 3)
	if len(parts) != 3 {
		return nil, &badStringError{"malformed request line", requestLine}
	}
	method, target, version := parts[0], parts[1], parts[2]
	if target != "*" {
		return nil, &badStringError{"bad URL request", target}
	}
	major, minor, ok := http.ParseHTTPVersion(strings.TrimSpace(version))
	if !ok {
		return nil, &badStringError{"malformed HTTP version", version}
	}

	header, err := reader.ReadMIMEHeader()
	if err != nil {
		return nil, err
	}
	return &http.Request{
		Method:     method,
		ProtoMajor: major,
		ProtoMinor: minor,
		Header:     http.Header(header),
	}, nil
}
// Server answers SSDP M-SEARCH discovery requests and periodically multicasts
// NOTIFY announcements on a single network interface. Set the exported
// fields, call Init, run Serve, and call Close to shut it down.
type Server struct {
// conn is the multicast UDP socket opened by Init (via makeConn); it is used
// for all reads and writes and is closed by Close.
conn *net.UDPConn
// Interface is the network interface to listen on. Its MTU sizes the read
// buffer, its Name prefixes log lines, and its addresses determine which IPs
// are passed to Location.
Interface net.Interface
// Server is sent verbatim as the SERVER header value.
Server string
// Services and Devices are additional target types, advertised after the
// root device and UUID (see allTypes).
Services []string
Devices []string
// Location returns the LOCATION header value to advertise for a given local
// interface address.
Location func(net.IP) string
// UUID is advertised as a target type of its own and prefixes the USN of
// every other target type (see usnFromTarget).
UUID string
// NotifyInterval is the pause between rounds of ssdp:alive announcements in
// Serve. The advertised CACHE-CONTROL max-age is 2.5 times this interval, in
// whole seconds.
NotifyInterval time.Duration
// closed is created by Init and closed by Close; serve and delayed sends
// watch it to know when to stop.
closed chan struct{}
}
// makeConn opens a UDP socket on ifi that has joined the SSDP multicast group
// NetAddr. It sets the multicast TTL to 2 and enables multicast loopback.
// Both options are best effort: a failure to set either one is logged and
// does not fail the call.
func makeConn(ifi net.Interface) (ret *net.UDPConn, err error) {
	if ret, err = net.ListenMulticastUDP("udp", &ifi, NetAddr); err != nil {
		return ret, err
	}
	pc := ipv4.NewPacketConn(ret)
	if ttlErr := pc.SetMulticastTTL(2); ttlErr != nil {
		log.Println(ttlErr)
	}
	if loopErr := pc.SetMulticastLoopback(true); loopErr != nil {
		log.Println(loopErr)
	}
	return ret, nil
}
// serve reads datagrams from the multicast socket and passes each one to
// handle on its own goroutine. It returns once the server has been closed
// (Close closes conn, which unblocks the pending read). It also returns after
// logging the first read error that was not caused by Close.
func (me *Server) serve() {
	// The interface MTU is only a hint; sanitize it before allocating.
	bufSize := udpReadBufferSize(me.Interface.MTU)
	for {
		b := make([]byte, bufSize)
		n, addr, err := me.conn.ReadFromUDP(b)
		select {
		case <-me.closed:
			return
		default:
		}
		if err != nil {
			log.Printf("error reading from UDP socket: %s", err)
			return
		}
		go me.handle(b[:n], addr)
	}
}

// udpReadBufferSize returns the buffer size to use for reading one datagram
// from an interface with the given MTU. An MTU in (0, 65536] is used as-is.
// Anything else falls back to 65536, which is at least as large as any UDP
// datagram. Without the fallback:
//   - an MTU of 0 (e.g. a zero net.Interface) truncates every datagram to
//     zero bytes;
//   - a negative MTU makes make panic;
//   - an implausibly large MTU allocates that much memory per datagram
//     (reportedly seen on some Windows interfaces — TODO confirm).
func udpReadBufferSize(mtu int) int {
	const maxDatagram = 65536
	if mtu <= 0 || mtu > maxDatagram {
		return maxDatagram
	}
	return mtu
}
// Init prepares the server for Serve. It creates the channel that Close uses
// to signal shutdown and opens the multicast socket on me.Interface. The
// socket field is assigned even when makeConn returns an error.
func (me *Server) Init() (err error) {
	me.closed = make(chan struct{})
	if me.conn, err = makeConn(me.Interface); err != nil {
		return err
	}
	return nil
}
// Close shuts the server down. It first closes the closed channel, so serve
// and any pending delayed sends stop. It then multicasts ssdp:byebye for every
// advertised target type, and finally closes the socket.
//
// Close must be called at most once, and only after Init: closing a nil or
// already-closed channel panics.
func (me *Server) Close() {
	// Deferred so the socket stays open for the byebye messages below.
	defer me.conn.Close()
	close(me.closed)
	me.sendByeBye()
}
// Serve starts the background datagram reader and then, once per
// NotifyInterval, multicasts ssdp:alive notifications for every advertised
// target type and every address of me.Interface.
//
// It returns nil once Close has been called. It returns the error if the
// interface's addresses cannot be listed.
func (me *Server) Serve() (err error) {
	go me.serve()
	for {
		var addrs []net.Addr
		if addrs, err = me.Interface.Addrs(); err != nil {
			return err
		}
		// Advertised cache lifetime: 2.5 notification intervals, truncated to
		// whole seconds. It is the same for every address in this round.
		maxAge := fmt.Sprintf("max-age=%d", 5*me.NotifyInterval/2/time.Second)
		for _, addr := range addrs {
			var ip net.IP
			switch val := addr.(type) {
			case *net.IPNet:
				ip = val.IP
			case *net.IPAddr:
				ip = val.IP
			default:
				panic(fmt.Sprint("unexpected addr type:", addr))
			}
			me.notifyAll(aliveNTS, [][2]string{
				{"CACHE-CONTROL", maxAge},
				{"LOCATION", me.Location(ip)},
			})
		}
		// Wait for the next round, but also watch closed so that Close ends
		// this loop instead of leaving it (and its goroutine) running forever.
		timer := time.NewTimer(me.NotifyInterval)
		select {
		case <-me.closed:
			timer.Stop()
			return nil
		case <-timer.C:
		}
	}
}
// usnFromTarget returns the USN header value for a target type. When the
// target is the server's UUID itself, the USN is just the UUID. For any other
// target it is "<UUID>::<target>".
func (me *Server) usnFromTarget(target string) string {
	if target != me.UUID {
		return me.UUID + "::" + target
	}
	return me.UUID
}
// makeNotifyMessage renders a complete "NOTIFY * HTTP/1.1" message for target
// with the given NTS value. It writes the HOST, NT, NTS, SERVER and USN
// headers in that order, followed by extraHdrs, and terminates the message
// with a blank line. extraHdrs is not modified.
func (me *Server) makeNotifyMessage(target, nts string, extraHdrs [][2]string) []byte {
	headers := append([][2]string{
		{"HOST", AddrString},
		{"NT", target},
		{"NTS", nts},
		{"SERVER", me.Server},
		{"USN", me.usnFromTarget(target)},
	}, extraHdrs...)

	var msg bytes.Buffer
	msg.WriteString("NOTIFY * HTTP/1.1\r\n")
	for _, kv := range headers {
		fmt.Fprintf(&msg, "%s: %s\r\n", kv[0], kv[1])
	}
	msg.WriteString("\r\n")
	return msg.Bytes()
}
// send writes buf to addr through the server socket. Write errors and short
// writes are logged rather than returned.
func (me *Server) send(buf []byte, addr *net.UDPAddr) {
	written, err := me.conn.WriteToUDP(buf, addr)
	switch {
	case err != nil:
		log.Printf("error writing to UDP socket: %s", err)
	case written != len(buf):
		log.Printf("short write: %d/%d bytes", written, len(buf))
	}
}
// delayedSend writes buf to addr after delay, on a new goroutine. If the
// server is closed before the delay elapses, nothing is sent.
func (me *Server) delayedSend(delay time.Duration, buf []byte, addr *net.UDPAddr) {
	go func() {
		// An explicit timer (rather than time.After) can be stopped when
		// closed wins. Otherwise it would linger until it fires, which for
		// search responses can be many seconds.
		timer := time.NewTimer(delay)
		defer timer.Stop()
		select {
		case <-timer.C:
			me.send(buf, addr)
		case <-me.closed:
		}
	}()
}
// log prints args through the standard logger, preceded by the interface
// name and a colon.
func (me *Server) log(args ...interface{}) {
	prefixed := make([]interface{}, 0, len(args)+1)
	prefixed = append(prefixed, me.Interface.Name+":")
	prefixed = append(prefixed, args...)
	log.Print(prefixed...)
}
// sendByeBye immediately multicasts an ssdp:byebye NOTIFY, without extra
// headers, for every advertised target type.
func (me *Server) sendByeBye() {
	targets := me.allTypes()
	for i := range targets {
		me.send(me.makeNotifyMessage(targets[i], byebyeNTS, nil), NetAddr)
	}
}
// notifyAll multicasts a NOTIFY with the given NTS value and extra headers
// for every advertised target type. Each message is sent after its own random
// delay below 100ms.
func (me *Server) notifyAll(nts string, extraHdrs [][2]string) {
	const maxJitter = 100 * time.Millisecond
	for _, target := range me.allTypes() {
		msg := me.makeNotifyMessage(target, nts, extraHdrs)
		me.delayedSend(time.Duration(rand.Int63n(int64(maxJitter))), msg, NetAddr)
	}
}
// allTypes lists every target type the server advertises, in this order: the
// root device, the server UUID, then me.Devices, then me.Services.
func (me *Server) allTypes() (ret []string) {
	ret = append(ret, rootDevice, me.UUID)
	ret = append(ret, me.Devices...)
	return append(ret, me.Services...)
}
// handle parses one datagram received from sender. If it is an SSDP discovery
// request (M-SEARCH with MAN: "ssdp:discover"), it schedules a response to
// sender for every matching target type and every address of me.Interface
// that can reach sender. Each response is delayed by a random duration below
// the request's MX value, in seconds.
//
// The datagram comes from the network, so problems are logged and the
// request is dropped. Nothing here panics on malformed input.
func (me *Server) handle(buf []byte, sender *net.UDPAddr) {
	req, err := ReadRequest(bufio.NewReader(bytes.NewReader(buf)))
	if err != nil {
		log.Println(err)
		return
	}
	if req.Method != "M-SEARCH" || req.Header.Get("man") != `"ssdp:discover"` {
		return
	}

	// Requests addressed to the multicast group must carry a parsable MX.
	// Any other request uses a one-second bound.
	var mx uint64 = 1
	if req.Header.Get("Host") == AddrString {
		mxHeader := req.Header.Get("mx")
		if mx, err = strconv.ParseUint(mxHeader, 0, 0); err != nil {
			log.Printf("Invalid mx header %q: %s", mxHeader, err)
			return
		}
	}
	// rand.Int63n panics on a non-positive bound. "MX: 0", or an MX big enough
	// to overflow the nanosecond product below, would otherwise crash the
	// process. Clamp MX to the 1..120 second range of UPnP Device
	// Architecture 1.0.
	const minMX, maxMX = 1, 120
	if mx < minMX {
		mx = minMX
	} else if mx > maxMX {
		mx = maxMX
	}

	targets := me.searchTypes(req.Header.Get("st"))
	if len(targets) == 0 {
		return
	}
	ips, err := me.replyIPs(sender.IP)
	if err != nil {
		log.Printf("error listing addresses of interface %q: %s", me.Interface.Name, err)
		return
	}
	for _, ip := range ips {
		for _, target := range targets {
			resp := me.makeResponse(ip, target, req)
			delay := time.Duration(rand.Int63n(int64(time.Second) * int64(mx)))
			me.delayedSend(delay, resp, sender)
		}
	}
}

// searchTypes returns the advertised target types that match an ST header:
// all of them for "ssdp:all", the single exact match otherwise, or nil when
// nothing matches.
func (me *Server) searchTypes(st string) []string {
	all := me.allTypes()
	if st == "ssdp:all" {
		return all
	}
	for _, t := range all {
		if t == st {
			return []string{t}
		}
	}
	return nil
}

// replyIPs returns the addresses of me.Interface to advertise in responses to
// sender: the IP of every network that contains sender, plus every plain IP
// address. Address types other than *net.IPNet and *net.IPAddr are skipped.
func (me *Server) replyIPs(sender net.IP) ([]net.IP, error) {
	addrs, err := me.Interface.Addrs()
	if err != nil {
		return nil, err
	}
	var ips []net.IP
	for _, addr := range addrs {
		switch a := addr.(type) {
		case *net.IPNet:
			if a.Contains(sender) {
				ips = append(ips, a.IP)
			}
		case *net.IPAddr:
			ips = append(ips, a.IP)
		}
	}
	return ips, nil
}
// makeResponse renders the HTTP-formatted 200 reply to the M-SEARCH request
// req for target type targ. LOCATION is the value me.Location returns for the
// local address ip. Header keys are canonicalized by http.Header.Set. It
// panics if the response cannot be serialized.
func (me *Server) makeResponse(ip net.IP, targ string, req *http.Request) (ret []byte) {
	header := make(http.Header)
	header.Set("CACHE-CONTROL", fmt.Sprintf("max-age=%d", 5*me.NotifyInterval/2/time.Second))
	header.Set("EXT", "")
	header.Set("LOCATION", me.Location(ip))
	header.Set("SERVER", me.Server)
	header.Set("ST", targ)
	header.Set("USN", me.usnFromTarget(targ))

	response := http.Response{
		StatusCode: 200,
		ProtoMajor: 1,
		ProtoMinor: 1,
		Header:     header,
		Request:    req,
	}
	var out bytes.Buffer
	if err := response.Write(&out); err != nil {
		panic(err)
	}
	return out.Bytes()
}