// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.

package rpc

import (
	"context"
	"crypto/tls"
	"net"
	"strings"
	"sync"
	"time"

	"github.com/zeebo/errs"
	"go.uber.org/zap"

	"storj.io/common/memory"
	"storj.io/common/netutil"
	"storj.io/common/pb"
	"storj.io/common/peertls/tlsopts"
	"storj.io/common/rpc/rpcpool"
	"storj.io/common/rpc/rpctracing"
	"storj.io/common/storj"
	"storj.io/drpc"
	"storj.io/drpc/drpcconn"
	"storj.io/drpc/drpcmanager"
	"storj.io/drpc/drpcstream"
)

// NewDefaultManagerOptions returns the default options we use for drpc managers.
func NewDefaultManagerOptions() drpcmanager.Options {
	return drpcmanager.Options{
		WriterBufferSize: 1024,
		Stream: drpcstream.Options{
			SplitSize: (4096 * 2) - 256,
		},
	}
}

// Dialer holds configuration for dialing.
type Dialer struct {
	// TLSOptions controls the tls options for dialing. If it is nil, only
	// insecure connections can be made.
	TLSOptions *tlsopts.Options

	// DialTimeout causes all the tcp dials to error if they take longer
	// than it if it is non-zero.
	DialTimeout time.Duration

	// DialLatency sleeps this amount if it is non-zero before every dial.
	// The timeout runs while the sleep is happening.
	DialLatency time.Duration

	// TransferRate limits all read/write operations to go slower than
	// the size per second if it is non-zero.
	TransferRate memory.Size

	// PoolOptions controls options for the connection pool.
	PoolOptions rpcpool.Options

	// ConnectionOptions controls the options that we pass to drpc connections.
	ConnectionOptions drpcconn.Options

	// TCPUserTimeout controls what setting to use for the TCP_USER_TIMEOUT
	// socket option on dialed connections. Only valid on linux. Only set
	// if positive.
	TCPUserTimeout time.Duration
}

// NewDefaultDialer returns a Dialer with default timeouts set.
func NewDefaultDialer(tlsOptions *tlsopts.Options) Dialer {
	return Dialer{
		TLSOptions:     tlsOptions,
		DialTimeout:    20 * time.Second,
		TCPUserTimeout: 15 * time.Minute,
		PoolOptions: rpcpool.Options{
			Capacity:       5,
			IdleExpiration: 2 * time.Minute,
		},
		ConnectionOptions: drpcconn.Options{
			Manager: NewDefaultManagerOptions(),
		},
	}
}

// dialContext does a raw tcp dial to the address and wraps the connection with the
// provided timeout.
func (d Dialer) dialContext(ctx context.Context, address string) (net.Conn, error) {
	if d.DialLatency > 0 {
		timer := time.NewTimer(d.DialLatency)
		select {
		case <-timer.C:
		case <-ctx.Done():
			timer.Stop()
			return nil, Error.Wrap(ctx.Err())
		}
	}

	conn, err := new(net.Dialer).DialContext(ctx, "tcp", address)
	if err != nil {
		// N.B. this error is not wrapped on purpose! grpc code cares about inspecting
		// it and it's not smart enough to attempt to do any unwrapping. :( Additionally
		// DialContext does not return an error that can be inspected easily to see if it
		// came from the context being canceled. Thus, we do this racy thing where if the
		// context is canceled at this point, we return it, rather than return the error
		// from dialing. It's a slight lie, but arguably still correct because the cancel
		// must be racing with the dial anyway.
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		default:
			return nil, err
		}
	}

	if tcpconn, ok := conn.(*net.TCPConn); d.TCPUserTimeout > 0 && ok {
		if err := netutil.SetUserTimeout(tcpconn, d.TCPUserTimeout); err != nil {
			return nil, errs.Combine(Error.Wrap(err), Error.Wrap(conn.Close()))
		}
	}

	return &timedConn{
		Conn: netutil.TrackClose(conn),
		rate: d.TransferRate,
	}, nil
}

// DialNode creates an rpc connection to the specified node.
func (d Dialer) DialNode(ctx context.Context, node *pb.Node) (_ *Conn, err error) {
	if node == nil {
		return nil, Error.New("node is nil")
	}

	defer mon.Task()(&ctx, "node: "+node.Id.String()[0:8])(&err)

	if d.TLSOptions == nil {
		return nil, Error.New("tls options not set when required for this dial")
	}

	return d.dial(ctx, node.GetAddress().GetAddress(), d.TLSOptions.ClientTLSConfig(node.Id))
}

// DialAddressID dials to the specified address and asserts it has the given node id.
func (d Dialer) DialAddressID(ctx context.Context, address string, id storj.NodeID) (_ *Conn, err error) {
	defer mon.Task()(&ctx)(&err)

	if d.TLSOptions == nil {
		return nil, Error.New("tls options not set when required for this dial")
	}

	return d.dial(ctx, address, d.TLSOptions.ClientTLSConfig(id))
}

// DialAddressInsecureBestEffort is like DialAddressInsecure but tries to dial a node securely if
// it can.
//
// nodeURL is like a storj.NodeURL but (a) requires an address and (b) does not require a
// full node id and will work with just a node prefix. The format is either:
//  * node_host:node_port
//  * node_id_prefix@node_host:node_port
// Examples:
//  * 33.20.0.1:7777
//  * [2001:db8:1f70::999:de8:7648:6e8]:7777
//  * 12vha9oTFnerx@33.20.0.1:7777
//  * 12vha9oTFnerx@[2001:db8:1f70::999:de8:7648:6e8]:7777
//
// DialAddressInsecureBestEffort:
//  * will use a node id if provided in the nodeURL paramenter
//  * will otherwise look up the node address in a known map of node address to node ids and use
// 		the remembered node id.
//  * will otherwise dial insecurely
func (d Dialer) DialAddressInsecureBestEffort(ctx context.Context, nodeURL string) (_ *Conn, err error) {
	defer mon.Task()(&ctx)(&err)

	if d.TLSOptions == nil {
		return nil, Error.New("tls options not set when required for this dial")
	}

	var nodeIDPrefix, nodeAddress string
	parts := strings.Split(nodeURL, "@")
	switch len(parts) {
	default:
		return nil, Error.New("malformed node url: %q", nodeURL)
	case 1:
		nodeAddress = parts[0]
	case 2:
		nodeIDPrefix, nodeAddress = parts[0], parts[1]
	}

	if len(nodeIDPrefix) > 0 {
		return d.dial(ctx, nodeAddress, d.TLSOptions.ClientTLSConfigPrefix(nodeIDPrefix))
	}

	if nodeID, found := KnownNodeID(nodeAddress); found {
		return d.dial(ctx, nodeAddress, d.TLSOptions.ClientTLSConfig(nodeID))
	}

	zap.L().Warn(`Unknown node id for address. Specify node id in the form "node_id@node_host:node_port" for added security`,
		zap.String("Address", nodeAddress),
	)
	return d.dial(ctx, nodeAddress, d.TLSOptions.UnverifiedClientTLSConfig())
}

// DialAddressInsecure dials to the specified address and does not check the node id.
func (d Dialer) DialAddressInsecure(ctx context.Context, address string) (_ *Conn, err error) {
	defer mon.Task()(&ctx)(&err)

	if d.TLSOptions == nil {
		return nil, Error.New("tls options not set when required for this dial")
	}

	return d.dial(ctx, address, d.TLSOptions.UnverifiedClientTLSConfig())
}

// DialAddressUnencrypted dials to the specified address without tls.
func (d Dialer) DialAddressUnencrypted(ctx context.Context, address string) (_ *Conn, err error) {
	defer mon.Task()(&ctx)(&err)

	return d.dialUnencrypted(ctx, address)
}

// drpcHeader is the first bytes we send on a connection so that the remote
// knows to expect drpc on the wire instead of grpc.
const drpcHeader = "DRPC!!!1"

// dial performs the dialing to the drpc endpoint with tls.
func (d Dialer) dial(ctx context.Context, address string, tlsConfig *tls.Config) (_ *Conn, err error) {
	defer mon.Task()(&ctx)(&err)

	// include the timeout here so that it includes all aspects of the dial
	if d.DialTimeout > 0 {
		var cancel func()
		ctx, cancel = context.WithTimeout(ctx, d.DialTimeout)
		defer cancel()
	}

	pool := rpcpool.New(d.PoolOptions, func(ctx context.Context) (drpc.Transport, error) {
		return d.dialTransport(ctx, address, tlsConfig)
	})

	conn, err := d.dialTransport(ctx, address, tlsConfig)
	if err != nil {
		return nil, err
	}
	state := conn.ConnectionState()

	if err := pool.Put(drpcconn.New(conn)); err != nil {
		return nil, err
	}

	return &Conn{
		state: state,
		Conn:  rpctracing.NewTracingWrapper(pool),
	}, nil
}

// dialTransport performs dialing to the drpc endpoint with tls.
func (d Dialer) dialTransport(ctx context.Context, address string, tlsConfig *tls.Config) (_ *tlsConnWrapper, err error) {
	defer mon.Task()(&ctx)(&err)

	// open the tcp socket to the address
	rawConn, err := d.dialContext(ctx, address)
	if err != nil {
		return nil, Error.Wrap(err)
	}
	rawConn = newDrpcHeaderConn(rawConn)

	// perform the handshake racing with the context closing. we use a buffer
	// of size 1 so that the handshake can proceed even if no one is reading.
	errCh := make(chan error, 1)
	conn := tls.Client(rawConn, tlsConfig)
	go func() { errCh <- conn.Handshake() }()

	// see which wins and close the raw conn if there was any error. we can't
	// close the tls connection concurrently with handshakes or it sometimes
	// will panic. cool, huh?
	select {
	case <-ctx.Done():
		err = ctx.Err()
	case err = <-errCh:
	}
	if err != nil {
		_ = rawConn.Close()
		return nil, Error.Wrap(err)
	}

	return &tlsConnWrapper{
		Conn:       conn,
		underlying: rawConn,
	}, nil
}

// dialUnencrypted performs dialing to the drpc endpoint with no tls.
func (d Dialer) dialUnencrypted(ctx context.Context, address string) (_ *Conn, err error) {
	defer mon.Task()(&ctx)(&err)

	// include the timeout here so that it includes all aspects of the dial
	if d.DialTimeout > 0 {
		var cancel func()
		ctx, cancel = context.WithTimeout(ctx, d.DialTimeout)
		defer cancel()
	}

	conn := rpcpool.New(d.PoolOptions, func(ctx context.Context) (drpc.Transport, error) {
		return d.dialTransportUnencrypted(ctx, address)
	})
	return &Conn{
		Conn: rpctracing.NewTracingWrapper(conn),
	}, nil
}

// dialTransportUnencrypted performs dialing to the drpc endpoint with no tls.
func (d Dialer) dialTransportUnencrypted(ctx context.Context, address string) (_ net.Conn, err error) {
	defer mon.Task()(&ctx)(&err)

	// open the tcp socket to the address
	conn, err := d.dialContext(ctx, address)
	if err != nil {
		return nil, Error.Wrap(err)
	}

	return newDrpcHeaderConn(conn), nil
}

// tlsConnWrapper is a wrapper around a *tls.Conn that calls Close on the
// underlying connection when closed rather than trying to send a
// notification to the other side which may block forever.
type tlsConnWrapper struct {
	*tls.Conn
	underlying net.Conn
}

// Close closes the underlying connection
func (t *tlsConnWrapper) Close() error { return t.underlying.Close() }

// drpcHeaderConn fulfills the net.Conn interface. On the first call to Write
// it will write the drpcHeader.
type drpcHeaderConn struct {
	net.Conn
	once sync.Once
}

// newDrpcHeaderConn returns a new *drpcHeaderConn
func newDrpcHeaderConn(conn net.Conn) *drpcHeaderConn {
	return &drpcHeaderConn{
		Conn: conn,
	}
}

// Write will write buf to the underlying conn. If this is the first time Write
// is called it will prepend the drpcHeader to the beginning of the write.
func (d *drpcHeaderConn) Write(buf []byte) (n int, err error) {
	var didOnce bool
	d.once.Do(func() {
		didOnce = true
		header := []byte(drpcHeader)
		n, err = d.Conn.Write(append(header, buf...))
	})
	if didOnce {
		n -= len(drpcHeader)
		if n < 0 {
			n = 0
		}
		return n, err
	}
	return d.Conn.Write(buf)
}