rpc/server.go

350 lines
8.1 KiB
Go

package rpc
import (
"compress/gzip"
"context"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"runtime/debug"
"strconv"
"strings"
"git.bit5.ru/backend/errors"
"git.bit5.ru/backend/meta"
"github.com/go-logr/logr"
"github.com/google/uuid"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
)
const (
	// tracerName is the instrumentation name under which this package
	// registers its OpenTelemetry tracer.
	tracerName = "git.bit5.ru/backend/rpc"

	// GenericErrorCode is the code written to the response for errors that
	// do not carry their own ServerError code.
	GenericErrorCode = -1
)

var (
	// tracer creates the spans emitted by this package.
	tracer = otel.Tracer(tracerName)

	// DevOnlyErr is a sentinel error for operations that are only
	// permitted when the server runs in dev mode.
	DevOnlyErr = errors.New("Allowed in dev mode only")
)
// ServerError is an error that carries an application-defined numeric Code
// and an optional human-readable Desc.
type ServerError struct {
	Code int
	Desc string
}

// Error implements the error interface by rendering both the code and the
// description.
func (se ServerError) Error() string {
	return fmt.Sprintf("code: %d, desc: %s", se.Code, se.Desc)
}

// NewServerError returns a *ServerError with the given code and an empty
// description.
func NewServerError(code int) error {
	se := &ServerError{Code: code}
	return se
}
// RPCResult is the outcome of serving one RPC request: the error the RPC
// produced (nil on success) and the encoded bytes to send back to the client.
type RPCResult struct {
	Error   error
	Payload []byte
}
// RPCServer is an http.Handler that decodes msgpack-encoded RPC requests,
// runs them through a per-request IRPCHandler and writes the encoded result.
// Construct it with NewRPCServer.
type RPCServer struct {
	// isDev enables dev-only behaviour such as per-request ids in logs and
	// response size logging.
	isDev bool
	// versionHeader is the name of the request header holding
	// "<client version>;<platform>".
	versionHeader string
	// gzipLenThreshold is the payload length above which responses are
	// gzipped for clients that accept it.
	gzipLenThreshold int
	logger           logr.Logger
	// newHandler builds the handler used for a single request.
	newHandler func(ctx *RPCCtx) IRPCHandler
	// createRPC instantiates an RPC by its numeric code.
	createRPC func(code uint32) (meta.IRPC, error)
}
// IRPCHandler hooks into the lifecycle of a single RPC invocation.
type IRPCHandler interface {
	// BeforeExecute runs before the RPC executes. A non-nil error skips
	// execution; AfterExecute is still invoked with that error.
	BeforeExecute(p meta.IRPC) error
	// AfterExecute runs after BeforeExecute/Execute with the error produced
	// so far; its return value becomes the final error of the RPC.
	AfterExecute(p meta.IRPC, err error) error
	// Close releases handler resources once the request has been served.
	Close()
}
// RPCCtx carries per-request information handed to newHandler when an
// IRPCHandler is created for a request.
//
// NOTE: field order is significant for callers that build it positionally.
type RPCCtx struct {
	// Ctx is the request context (already carrying the tracing span).
	Ctx context.Context
	// Log is the request-scoped logger.
	Log logr.Logger
	// ClientVersion is the version part of the client version header.
	ClientVersion string
	// ClientIP is the resolved client address; nil if it could not be parsed.
	ClientIP net.IP
	// ClientPlatform is the platform part of the version header, 0 if absent
	// or malformed.
	ClientPlatform int
	// IsDev mirrors the server's dev-mode flag.
	IsDev bool
}
// NewRPCServer builds an RPCServer. The logger is scoped with the name
// "[rpc]"; newHandler is called once per request and createRPC maps a
// numeric RPC code to a fresh RPC instance.
func NewRPCServer(logger logr.Logger,
	isDev bool,
	versionHeader string,
	gzipLenThreshold int,
	newHandler func(ctx *RPCCtx) IRPCHandler,
	createRPC func(code uint32) (meta.IRPC, error)) *RPCServer {
	srv := new(RPCServer)
	srv.isDev = isDev
	srv.versionHeader = versionHeader
	srv.gzipLenThreshold = gzipLenThreshold
	srv.logger = logger.WithName("[rpc]")
	srv.newHandler = newHandler
	srv.createRPC = createRPC
	return srv
}
// ServeHTTP implements http.Handler. It reads the whole request body, decodes
// and executes the RPC it contains, and writes the msgpack-encoded result
// (optionally gzipped). RPC-level errors are encoded into the response payload
// rather than signalled via HTTP status; only a failure to read the request
// body produces a non-200 status.
func (server *RPCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	ctx, span := tracer.Start(r.Context(), "RPCServer.ServeHTTP")
	defer span.End()

	w.Header().Set("Content-Type", "application/text; charset=utf-8")

	canGzip := strings.Contains(r.Header.Get("Accept-Encoding"), "gzip")
	cltVersion, cltPlatform := parseVersionHeader(r.Header.Get(server.versionHeader))
	cltIP := GetIP(r)

	span.SetAttributes(
		attribute.String("request.method", r.Method),
		attribute.String("request.proto", r.Proto),
		attribute.Int("request.proto_major", r.ProtoMajor),
		attribute.Int("request.proto_minor", r.ProtoMinor),
		attribute.Int64("request.content_length", r.ContentLength),
		attribute.String("request.host", r.Host),
		attribute.String("request.remote_addr", r.RemoteAddr),
		attribute.String("request.request_uri", r.RequestURI),
		attribute.String("client_ver", cltVersion),
		attribute.Int("client_platform", cltPlatform),
		attribute.String("client_ip", cltIP.String()),
		attribute.Bool("can_gzip", canGzip),
	)
	if url := r.URL; url != nil {
		span.SetAttributes(
			attribute.String("request.url.scheme", url.Scheme),
			attribute.String("request.url.host", url.Host),
			attribute.String("request.url.path", url.Path),
			attribute.String("request.url.raw_query", url.RawQuery),
		)
	}

	// Build the request-scoped logger first so that every message for this
	// request (including "New request" and body read failures) carries the
	// dev-mode request id.
	logger := server.logger
	if server.isDev {
		logger = logger.WithValues("uid", uuid.New().String())
	}
	logger.V(1).Info("New request", "client_ver", cltVersion)

	payload, err := readRequestBody(ctx, r.Body)
	if err != nil {
		span.SetStatus(codes.Error, "Got error from readRequestBody.")
		span.RecordError(err)
		logger.Error(err, "Error reading body", "client_ver", cltVersion, "client_platform", cltPlatform, "client_ip", cltIP, "body_len", len(payload))
		// Returning without writing anything made net/http reply with an
		// empty 200 OK, indistinguishable from success at the HTTP level.
		http.Error(w, "Error reading body", http.StatusBadRequest)
		return
	}

	result := server.serveRequest(ctx, logger, payload, cltVersion, cltPlatform, cltIP)
	if result.Error != nil {
		// The error is still delivered to the client inside the payload;
		// record it on the span so traces show which requests failed.
		span.RecordError(result.Error)
	}
	server.writeResponse(canGzip, w, result.Payload)
}
// readRequestBody reads body to EOF inside its own tracing span and records
// the number of bytes read. Whatever was read is returned even on error.
//
// NOTE(review): the body is read fully into memory with no size limit.
func readRequestBody(ctx context.Context, body io.Reader) ([]byte, error) {
	_, span := tracer.Start(ctx, "readRequestBody")
	defer span.End()

	data, readErr := ioutil.ReadAll(body)
	span.SetAttributes(attribute.Int("body_len", len(data)))
	if readErr == nil {
		return data, nil
	}

	span.SetStatus(codes.Error, "Error reading body. Got error from ioutil.ReadAll.")
	span.RecordError(readErr)
	return data, readErr
}
// GetIP returns the client IP of r as resolved by GetIPStr, or nil when the
// resolved string is not a valid IP address.
func GetIP(r *http.Request) net.IP {
	return net.ParseIP(GetIPStr(r))
}

// GetIPStr resolves the client address of r. It prefers the X-Real-IP header,
// then the first entry of X-Forwarded-For, and finally the host part of
// r.RemoteAddr (or RemoteAddr verbatim when it carries no port).
//
// NOTE(review): both headers are client-controlled unless a trusted reverse
// proxy overwrites them; do not rely on the result for authorization.
func GetIPStr(r *http.Request) string {
	if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}
	if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" {
		// The header is a comma-separated list ("client, proxy1, ..."). Not
		// every proxy puts a space after the comma, so split on "," and trim.
		if first := strings.TrimSpace(strings.SplitN(fwd, ",", 2)[0]); first != "" {
			return first
		}
	}
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// RemoteAddr without a port (e.g. rewritten by middleware): use as is
		// instead of discarding it.
		return r.RemoteAddr
	}
	return host
}
// writeResponse writes payload to w, gzip-compressing it when the client
// accepts gzip and the payload is longer than gzipLenThreshold. An empty
// payload is logged as an error and nothing is written. In dev mode the
// number of bytes actually written (after compression) is logged.
func (server *RPCServer) writeResponse(canGzip bool, w http.ResponseWriter, payload []byte) {
	if len(payload) == 0 {
		server.logger.Error(nil, "Nothing was written to response!")
		return
	}

	var wr io.Writer = w
	wrLen := 0
	if server.isDev {
		wr = &respWriterWithLen{w, &wrLen}
	}

	if canGzip && len(payload) > server.gzipLenThreshold {
		w.Header().Set("Content-Encoding", "gzip")
		server.logger.V(1).Info("Response is gzipped")
		gz := gzip.NewWriter(wr)
		_, err := gz.Write(payload)
		// Close explicitly rather than via defer: the compressor buffers its
		// output, so the stream must be flushed before the dev-mode size log
		// below, otherwise the reported length misses most of the data.
		if closeErr := gz.Close(); err == nil {
			err = closeErr
		}
		if err != nil {
			server.logger.Error(err, "Error writing gzipped response")
		}
	} else if _, err := wr.Write(payload); err != nil {
		server.logger.Error(err, "Error writing response")
	}

	if server.isDev {
		server.logger.V(1).Info("Response size", "len", wrLen, "orig", len(payload))
	}
}
// respWriterWithLen wraps an http.ResponseWriter and adds to *Len the number
// of bytes each Write call reports as written.
type respWriterWithLen struct {
	http.ResponseWriter
	Len *int
}

// Write forwards b to the wrapped ResponseWriter and accumulates the byte
// count it returns into *w.Len before passing the result through.
func (w *respWriterWithLen) Write(b []byte) (n int, err error) {
	written, writeErr := w.ResponseWriter.Write(b)
	*w.Len += written
	return written, writeErr
}
// serveRequest decodes and executes a single RPC from the raw request bytes
// and always returns an RPCResult whose Payload is ready to send.
//
// A panic anywhere in handler construction, execution, encoding or h.Close
// (which is deferred after the recover handler and therefore runs first) is
// recovered and converted into an error result.
func (server *RPCServer) serveRequest(ctx context.Context,
	logger logr.Logger,
	request []byte,
	clientVersion string,
	clientPlatform int,
	clientIP net.IP) (result RPCResult) {
	defer func() {
		rec := recover()
		if rec == nil {
			return
		}
		stackStr := errors.FormatPanicDebugStack(string(debug.Stack()))
		panicErr := errors.Errorf("[PANIC] %v %s", rec, stackStr)
		// Assigning to the named result replaces whatever would have been
		// returned from the panicking call.
		result.Error = panicErr
		// NOTE(review): panicErr is not a ServerError, so its full text,
		// including the stack trace, is encoded into the client response.
		// Consider restricting that to dev mode.
		result.Payload = getResponseBytes(nil, panicErr)
		// Pass panicErr (not nil) so the stack trace reaches the log.
		logger.Error(panicErr, "RPC panic", "PANIC", rec, "client_ver", clientVersion)
	}()

	h := server.newHandler(&RPCCtx{
		Ctx:            ctx,
		Log:            logger,
		ClientVersion:  clientVersion,
		ClientIP:       clientIP,
		ClientPlatform: clientPlatform,
		IsDev:          server.isDev,
	})
	defer h.Close()

	p, err := server.executeRPC(h, request)
	result.Error = err
	result.Payload = getResponseBytes(p, err)
	return result
}
// executeRPC decodes request into an RPC and runs it through h. Execute is
// skipped when BeforeExecute fails; AfterExecute always runs once a packet
// was decoded and its return value is the final error. The decoded packet is
// returned even on error so callers can inspect it; it is nil only when
// decoding fails.
func (server *RPCServer) executeRPC(h IRPCHandler, request []byte) (meta.IRPC, error) {
	packet, decodeErr := DecodeRequestToRPC(request, server.createRPC)
	if decodeErr != nil {
		return nil, decodeErr
	}

	var runErr error
	if beforeErr := h.BeforeExecute(packet); beforeErr != nil {
		runErr = beforeErr
	} else {
		runErr = packet.Execute(h)
	}
	return packet, h.AfterExecute(packet, runErr)
}
// DecodeRequestToRPC parses a msgpack request: it opens the outer container,
// reads the uint32 RPC code, instantiates the RPC via createById and reads
// the RPC's request fields from the remaining stream. Errors from reading the
// request body are wrapped with the packet code.
func DecodeRequestToRPC(request []byte, createById meta.RPCFactory) (meta.IRPC, error) {
	reader, err := meta.NewMsgpackReader(request)
	if err != nil {
		return nil, err
	}
	if err = reader.BeginContainer(""); err != nil {
		return nil, err
	}

	var code uint32
	if err = reader.ReadU32(&code, ""); err != nil {
		return nil, err
	}

	packet, err := createById(code)
	if err != nil {
		return nil, err
	}
	if err = packet.GetRequest().Read(reader); err != nil {
		return nil, errors.Wrapf(err, "Packet code: %d ", code)
	}
	return packet, nil
}
// getResponseBytes encodes the response container sent to the client.
// On success it holds the RPC code followed by the RPC's response fields.
// On failure it holds an error code and description: the ServerError's own
// Code/Desc when err wraps a ServerError, otherwise GenericErrorCode and
// err.Error(). rpc may be nil only when err is non-nil.
//
// Encoding errors are ignored here; a nil result is reported downstream as an
// empty response.
func getResponseBytes(rpc meta.IRPC, err error) []byte {
	writer := meta.NewMsgpackWriter()
	writer.BeginContainer("")
	if err == nil {
		writer.WriteI32(rpc.GetCode(), "")
		rpc.GetResponse().Write(writer)
	} else {
		code, desc := responseErrorCodeDesc(err)
		writer.WriteI32(code, "")
		writer.WriteString(desc, "")
	}
	writer.EndContainer()
	bytes, _ := writer.GetData()
	return bytes
}

// responseErrorCodeDesc extracts the code/description pair to encode for err.
// ServerError implements error with a value receiver, so both *ServerError
// (as built by NewServerError) and plain ServerError values are valid errors;
// matching only the pointer form silently turned value errors into
// GenericErrorCode.
func responseErrorCodeDesc(err error) (int32, string) {
	var ptrErr *ServerError
	if errors.As(err, &ptrErr) {
		return int32(ptrErr.Code), ptrErr.Desc
	}
	var valErr ServerError
	if errors.As(err, &valErr) {
		return int32(valErr.Code), valErr.Desc
	}
	return GenericErrorCode, err.Error()
}
// parseVersionHeader splits a "<version>;<platform>[;...]" header value into
// the client version string and numeric platform. Surrounding whitespace is
// ignored. The platform defaults to 0 when it is missing or not an integer.
// Any parts after the second are ignored.
func parseVersionHeader(header string) (string, int) {
	// strings.Split always yields at least one element, so parts[0] is safe.
	parts := strings.Split(header, ";")
	version := strings.TrimSpace(parts[0])
	platform := 0
	if len(parts) > 1 {
		// A malformed platform deliberately falls back to 0.
		if p, err := strconv.Atoi(strings.TrimSpace(parts[1])); err == nil {
			platform = p
		}
	}
	return version, platform
}