350 lines
8.1 KiB
Go
350 lines
8.1 KiB
Go
package rpc
|
|
|
|
import (
|
|
"compress/gzip"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"net"
|
|
"net/http"
|
|
"runtime/debug"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"git.bit5.ru/backend/errors"
|
|
"git.bit5.ru/backend/meta"
|
|
|
|
"github.com/go-logr/logr"
|
|
"github.com/google/uuid"
|
|
|
|
"go.opentelemetry.io/otel"
|
|
"go.opentelemetry.io/otel/attribute"
|
|
"go.opentelemetry.io/otel/codes"
|
|
)
|
|
|
|
const tracerName = "git.bit5.ru/backend/rpc"
|
|
|
|
var tracer = otel.Tracer(tracerName)
|
|
|
|
var DevOnlyErr = errors.New("Allowed in dev mode only")
|
|
|
|
const GenericErrorCode = -1
|
|
|
|
// ServerError is an RPC error carrying an explicit numeric code and a
// description. When getResponseBytes finds a *ServerError in an error chain,
// it writes Code and Desc to the client instead of the generic error code and
// message.
type ServerError struct {
	Code int
	Desc string
}

// Error implements the error interface by formatting both the code and the
// description.
func (se ServerError) Error() string {
	return fmt.Sprintf("code: %d, desc: %s", se.Code, se.Desc)
}

// NewServerError returns a *ServerError with the given code and an empty
// description.
func NewServerError(code int) error {
	se := &ServerError{Code: code}
	return se
}
|
|
|
|
// RPCResult struct to unify response from RPC handlers
|
|
type RPCResult struct {
|
|
Error error
|
|
Payload []byte
|
|
}
|
|
|
|
type RPCServer struct {
|
|
isDev bool
|
|
versionHeader string
|
|
gzipLenThreshold int
|
|
logger logr.Logger
|
|
newHandler func(ctx *RPCCtx) IRPCHandler
|
|
createRPC func(code uint32) (meta.IRPC, error)
|
|
}
|
|
|
|
type IRPCHandler interface {
|
|
BeforeExecute(p meta.IRPC) error
|
|
AfterExecute(p meta.IRPC, err error) error
|
|
Close()
|
|
}
|
|
|
|
type RPCCtx struct {
|
|
Ctx context.Context
|
|
Log logr.Logger
|
|
ClientVersion string
|
|
ClientIP net.IP
|
|
ClientPlatform int
|
|
IsDev bool
|
|
}
|
|
|
|
func NewRPCServer(logger logr.Logger,
|
|
isDev bool,
|
|
versionHeader string,
|
|
gzipLenThreshold int,
|
|
newHandler func(ctx *RPCCtx) IRPCHandler,
|
|
createRPC func(code uint32) (meta.IRPC, error)) *RPCServer {
|
|
return &RPCServer{
|
|
isDev: isDev,
|
|
versionHeader: versionHeader,
|
|
gzipLenThreshold: gzipLenThreshold,
|
|
logger: logger.WithName("[rpc]"),
|
|
newHandler: newHandler,
|
|
createRPC: createRPC}
|
|
}
|
|
|
|
func (server *RPCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
|
|
ctx, span := tracer.Start(r.Context(), "RPCServer.ServeHTTP")
|
|
defer span.End()
|
|
|
|
w.Header().Set("Content-Type", "application/text; charset=utf-8")
|
|
|
|
canGzip := strings.Contains(r.Header.Get("Accept-Encoding"), "gzip")
|
|
|
|
versionHeader := r.Header.Get(server.versionHeader)
|
|
cltVersion, cltPlatform := parseVersionHeader(versionHeader)
|
|
cltIP := GetIP(r)
|
|
|
|
span.SetAttributes(
|
|
attribute.String("request.method", r.Method),
|
|
attribute.String("request.proto", r.Proto),
|
|
attribute.Int("request.proto_major", r.ProtoMajor),
|
|
attribute.Int("request.proto_minor", r.ProtoMinor),
|
|
attribute.Int64("request.content_length", r.ContentLength),
|
|
attribute.String("request.host", r.Host),
|
|
attribute.String("request.remote_addr", r.RemoteAddr),
|
|
attribute.String("request.request_uri", r.RequestURI),
|
|
attribute.String("client_ver", cltVersion),
|
|
attribute.Int("client_platform", cltPlatform),
|
|
attribute.String("client_ip", cltIP.String()),
|
|
attribute.Bool("can_gzip", canGzip),
|
|
)
|
|
|
|
if url := r.URL; url != nil {
|
|
span.SetAttributes(
|
|
attribute.String("request.url.scheme", url.Scheme),
|
|
attribute.String("request.url.host", url.Host),
|
|
attribute.String("request.url.path", url.Path),
|
|
attribute.String("request.url.raw_query", url.RawQuery),
|
|
)
|
|
}
|
|
|
|
server.logger.V(1).Info("New request", "client_ver", cltVersion)
|
|
|
|
payload, err := readRequestBody(ctx, r.Body)
|
|
if err != nil {
|
|
payloadLen := len(payload)
|
|
span.SetStatus(codes.Error, "Got error from readRequestBody.")
|
|
span.RecordError(err)
|
|
server.logger.Error(err, "Error reading body", "client_ver", cltVersion, "client_platform", cltPlatform, "client_ip", cltIP, "body_len", payloadLen)
|
|
return
|
|
}
|
|
|
|
logger := server.logger
|
|
if server.isDev {
|
|
reqId := uuid.New()
|
|
logger = logger.WithValues("uid", reqId.String())
|
|
}
|
|
|
|
result := server.serveRequest(ctx, logger, payload, cltVersion, cltPlatform, cltIP)
|
|
|
|
server.writeResponse(canGzip, w, result.Payload)
|
|
}
|
|
|
|
func readRequestBody(ctx context.Context, body io.Reader) ([]byte, error) {
|
|
_, span := tracer.Start(ctx, "readRequestBody")
|
|
defer span.End()
|
|
|
|
payload, err := ioutil.ReadAll(body)
|
|
|
|
span.SetAttributes(
|
|
attribute.Int("body_len", len(payload)),
|
|
)
|
|
|
|
if err != nil {
|
|
span.SetStatus(codes.Error, "Error reading body. Got error from ioutil.ReadAll.")
|
|
span.RecordError(err)
|
|
return payload, err
|
|
}
|
|
|
|
return payload, nil
|
|
}
|
|
|
|
// GetIP returns the client IP of r as determined by GetIPStr. It returns nil
// when that string is not a valid textual IP address.
func GetIP(r *http.Request) net.IP {
	return net.ParseIP(GetIPStr(r))
}

// GetIPStr returns a best-effort client IP address for r as a string.
//
// Sources are tried in this order:
//   - the X-Real-IP header;
//   - the first (left-most) entry of the comma-separated X-Forwarded-For header;
//   - the host part of r.RemoteAddr, or RemoteAddr itself if it has no port.
//
// Surrounding whitespace is trimmed. A source that is empty after trimming
// falls through to the next one.
//
// NOTE(review): both headers are client-controlled unless a trusted reverse
// proxy overwrites them. Do not use the result for access control.
func GetIPStr(r *http.Request) string {
	if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}

	if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" {
		// Entries are separated by ",", and the space after the comma is
		// optional. Splitting on ", " missed the "a,b" form.
		first := fwd
		if i := strings.IndexByte(fwd, ','); i >= 0 {
			first = fwd[:i]
		}
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// RemoteAddr without a port (e.g. set by hand or by some listeners):
		// use it as-is instead of returning "".
		return r.RemoteAddr
	}
	return host
}
|
|
|
|
func (server *RPCServer) writeResponse(canGzip bool, w http.ResponseWriter, payload []byte) {
|
|
if len(payload) == 0 {
|
|
server.logger.Error(nil, "Nothing was written to response!")
|
|
return
|
|
}
|
|
var wr io.Writer = w
|
|
wrLen := 0
|
|
if server.isDev {
|
|
wr = &respWriterWithLen{w, &wrLen}
|
|
}
|
|
|
|
if canGzip && len(payload) > server.gzipLenThreshold {
|
|
w.Header().Set("Content-Encoding", "gzip")
|
|
server.logger.V(1).Info("Response is gzipped")
|
|
gz := gzip.NewWriter(wr)
|
|
defer gz.Close()
|
|
gz.Write(payload)
|
|
} else {
|
|
wr.Write(payload)
|
|
}
|
|
|
|
if server.isDev {
|
|
server.logger.V(1).Info("Response size", "len", wrLen, "orig", len(payload))
|
|
}
|
|
}
|
|
|
|
// respWriterWithLen wraps an http.ResponseWriter and adds the number of bytes
// written through it to *Len. writeResponse uses it in dev mode to measure the
// size of the response actually sent.
type respWriterWithLen struct {
	http.ResponseWriter
	Len *int // counter owned by the creator; incremented by every Write
}

// Write forwards b to the wrapped writer. It adds the byte count reported by
// that writer to *Len, including on partial writes.
func (w *respWriterWithLen) Write(b []byte) (int, error) {
	written, err := w.ResponseWriter.Write(b)
	*w.Len += written
	return written, err
}
|
|
|
|
func (server *RPCServer) serveRequest(context context.Context,
|
|
logger logr.Logger,
|
|
request []byte,
|
|
clientVersion string,
|
|
clientPlatform int,
|
|
clientIP net.IP) (result RPCResult) {
|
|
|
|
result = RPCResult{}
|
|
|
|
defer func() {
|
|
if rec := recover(); rec != nil {
|
|
stackStr := errors.FormatPanicDebugStack(string(debug.Stack()))
|
|
// set error that returns after panic handling. See named return values
|
|
panicErr := errors.Errorf("[PANIC] %v %s", rec, stackStr)
|
|
|
|
// set result that returns after panic handling. See named return values
|
|
result.Error = panicErr
|
|
result.Payload = getResponseBytes(nil, panicErr)
|
|
|
|
logger.Error(nil, "RPC panic", "PANIC", rec, "client_ver", clientVersion)
|
|
}
|
|
}()
|
|
|
|
h := server.newHandler(&RPCCtx{context, logger, clientVersion, clientIP, clientPlatform, server.isDev})
|
|
defer h.Close()
|
|
|
|
p, err := server.executeRPC(h, request)
|
|
|
|
result.Error = err
|
|
result.Payload = getResponseBytes(p, err)
|
|
|
|
return
|
|
}
|
|
|
|
func (server *RPCServer) executeRPC(h IRPCHandler, request []byte) (meta.IRPC, error) {
|
|
p, err := DecodeRequestToRPC(request, server.createRPC)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := h.BeforeExecute(p); err != nil {
|
|
err = h.AfterExecute(p, err)
|
|
//NOTE: returning constructed packet so that it can be inspected
|
|
return p, err
|
|
}
|
|
err = p.Execute(h)
|
|
err = h.AfterExecute(p, err)
|
|
|
|
if err != nil {
|
|
return p, err
|
|
}
|
|
|
|
return p, nil
|
|
}
|
|
|
|
func DecodeRequestToRPC(request []byte, createById meta.RPCFactory) (meta.IRPC, error) {
|
|
reader, err := meta.NewMsgpackReader(request)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var code uint32
|
|
err = reader.BeginContainer("")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = reader.ReadU32(&code, "")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rpc, err := createById(code)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
req := rpc.GetRequest()
|
|
err = req.Read(reader)
|
|
if err != nil {
|
|
err = errors.Wrapf(err, "Packet code: %d ", code)
|
|
return nil, err
|
|
}
|
|
return rpc, nil
|
|
}
|
|
|
|
func getResponseBytes(rpc meta.IRPC, err error) []byte {
|
|
writer := meta.NewMsgpackWriter()
|
|
writer.BeginContainer("")
|
|
|
|
if err == nil {
|
|
writer.WriteI32(rpc.GetCode(), "")
|
|
rpc.GetResponse().Write(writer)
|
|
} else {
|
|
var rpcErr *ServerError
|
|
if errors.As(err, &rpcErr) {
|
|
writer.WriteI32(int32(rpcErr.Code), "")
|
|
writer.WriteString(rpcErr.Desc, "")
|
|
} else {
|
|
writer.WriteI32(GenericErrorCode, "")
|
|
writer.WriteString(err.Error(), "")
|
|
}
|
|
}
|
|
|
|
writer.EndContainer()
|
|
|
|
bytes, _ := writer.GetData()
|
|
return bytes
|
|
}
|
|
|
|
// parseVersionHeader splits a client version header of the form
// "<version>;<platform>[;...]" into the version string and the numeric
// platform id.
//
// Surrounding whitespace is trimmed from both fields, and anything after the
// second ';' is ignored. A missing or non-numeric platform yields 0; an empty
// header yields ("", 0).
func parseVersionHeader(header string) (string, int) {
	// SplitN with n > 0 always returns at least one element, so parts[0] is safe.
	parts := strings.SplitN(header, ";", 3)
	version := strings.TrimSpace(parts[0])

	platform := 0
	if len(parts) > 1 {
		// Parse errors are deliberately ignored: an unknown platform is reported as 0.
		platform, _ = strconv.Atoi(strings.TrimSpace(parts[1]))
	}
	return version, platform
}
|