// rpc/server.go
//
// 292 lines
// 6.4 KiB
// Go
// Raw Permalink Normal View History
//
// 2022-11-15 15:29:42 +03:00

package rpc
import (
"compress/gzip"
"context"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"runtime/debug"
"strconv"
"strings"
"git.bit5.ru/backend/errors"
"git.bit5.ru/backend/meta"
"github.com/go-logr/logr"
"github.com/google/uuid"
)
// DevOnlyErr reports an operation that, per its message, is allowed in dev
// mode only. It is not referenced in the visible part of this file.
var (
	DevOnlyErr = errors.New("Allowed in dev mode only")
)

// GenericErrorCode is the error code written into an RPC error response when
// the failure does not carry a ServerError code of its own.
const (
	GenericErrorCode = -1
)
// ServerError is an application-level RPC failure with a numeric code and an
// optional description; the RPC layer encodes both into the error response
// instead of a generic error.
type ServerError struct {
	Code int
	Desc string
}

// Error implements the error interface by rendering the code and the
// description.
func (e ServerError) Error() string {
	return fmt.Sprintf("code: %v, desc: %v", e.Code, e.Desc)
}

// NewServerError returns a *ServerError carrying code and an empty
// description.
func NewServerError(code int) error {
	se := new(ServerError)
	se.Code = code
	return se
}
// RPCResult unifies the outcome of serving one RPC request: the error, if
// any, and the encoded response payload sent back to the client.
type RPCResult struct {
	// Error is the failure from decoding, executing, or a recovered panic;
	// nil on success.
	Error error
	// Payload is the encoded response body, produced for both successful
	// and failed requests.
	Payload []byte
}
type RPCServer struct {
2022-11-15 17:53:34 +03:00
isDev bool
versionHeader string
gzipLenThreshold int
logger logr.Logger
newHandler func(ctx *RPCCtx) IRPCHandler
createRPC func(code uint32) (meta.IRPC, error)
2022-11-15 15:29:42 +03:00
}
type IRPCHandler interface {
BeforeExecute(p meta.IRPC) error
AfterExecute(p meta.IRPC, err error) error
2022-11-15 17:53:34 +03:00
Close()
2022-11-15 15:29:42 +03:00
}
type RPCCtx struct {
Ctx context.Context
2022-11-15 17:53:34 +03:00
Log logr.Logger
2022-11-15 15:29:42 +03:00
ClientVersion string
ClientIP net.IP
ClientPlatform int
2022-11-15 17:53:34 +03:00
IsDev bool
2022-11-15 15:29:42 +03:00
}
func NewRPCServer(logger logr.Logger,
2022-11-15 17:53:34 +03:00
isDev bool,
versionHeader string,
gzipLenThreshold int,
newHandler func(ctx *RPCCtx) IRPCHandler,
createRPC func(code uint32) (meta.IRPC, error)) *RPCServer {
return &RPCServer{
isDev: isDev,
versionHeader: versionHeader,
gzipLenThreshold: gzipLenThreshold,
logger: logger.WithName("[rpc]"),
newHandler: newHandler,
createRPC: createRPC}
2022-11-15 15:29:42 +03:00
}
func (server *RPCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/text; charset=utf-8")
canGzip := strings.Contains(r.Header.Get("Accept-Encoding"), "gzip")
versionHeader := r.Header.Get(server.versionHeader)
cltVersion, cltPlatform := parseVersionHeader(versionHeader)
cltIP := GetIP(r)
server.logger.V(1).Info("New request", "client_ver", cltVersion)
payload, err := ioutil.ReadAll(r.Body)
if err != nil {
2024-04-12 20:21:13 +03:00
payloadLen := len(payload)
server.logger.Error(err, "Error reading body", "client_ver", cltVersion, "client_ip", cltIP, "body_len", payloadLen)
2022-11-15 15:29:42 +03:00
return
}
2022-11-15 17:53:34 +03:00
logger := server.logger
2022-11-15 15:29:42 +03:00
if server.isDev {
reqId := uuid.New()
logger = logger.WithValues("uid", reqId.String())
}
result := server.serveRequest(r.Context(), logger, payload, cltVersion, cltPlatform, cltIP)
server.writeResponse(canGzip, w, result.Payload)
}
// GetIP returns the client IP of r as resolved by GetIPStr, or nil when that
// string is not a valid IP address.
func GetIP(r *http.Request) net.IP {
	return net.ParseIP(GetIPStr(r))
}

// GetIPStr returns the client address of r as a string. Sources, in order:
//   - the X-Real-IP header;
//   - the first entry of X-Forwarded-For (the comma-separated list may or may
//     not have spaces after commas);
//   - the host part of r.RemoteAddr, or RemoteAddr itself when it has no
//     port.
//
// NOTE: both headers are client-controlled unless a trusted proxy overwrites
// them; do not rely on the result for access control.
func GetIPStr(r *http.Request) string {
	if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}
	if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" {
		first := fwd
		if i := strings.IndexByte(fwd, ','); i >= 0 {
			first = fwd[:i]
		}
		if first = strings.TrimSpace(first); first != "" {
			return first
		}
	}
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// RemoteAddr without a port (e.g. set by middleware or tests).
		return r.RemoteAddr
	}
	return host
}
// writeResponse writes payload to w. It gzip-compresses the payload when
// canGzip is set and the payload is longer than gzipLenThreshold. An empty
// payload is logged as an error and nothing is written. Write failures are
// logged. In dev mode it also logs the number of bytes actually written to w
// next to the original payload length.
func (server *RPCServer) writeResponse(canGzip bool, w http.ResponseWriter, payload []byte) {
	if len(payload) == 0 {
		server.logger.Error(nil, "Nothing was written to response!")
		return
	}

	var wr io.Writer = w
	wrLen := 0
	if server.isDev {
		wr = &respWriterWithLen{w, &wrLen}
	}

	if canGzip && len(payload) > server.gzipLenThreshold {
		w.Header().Set("Content-Encoding", "gzip")
		server.logger.V(1).Info("Response is gzipped")
		gz := gzip.NewWriter(wr)
		_, err := gz.Write(payload)
		// Close explicitly rather than via defer: Close flushes the pending
		// compressed data and the gzip trailer. That output must reach w
		// before wrLen is read below.
		if closeErr := gz.Close(); err == nil {
			err = closeErr
		}
		if err != nil {
			server.logger.Error(err, "Error writing gzipped response")
		}
	} else if _, err := wr.Write(payload); err != nil {
		server.logger.Error(err, "Error writing response")
	}

	if server.isDev {
		server.logger.V(1).Info("Response size", "len", wrLen, "orig", len(payload))
	}
}
// respWriterWithLen wraps an http.ResponseWriter and adds every byte count
// reported by the wrapped Write to *Len. All other ResponseWriter methods
// are promoted from the embedded writer unchanged.
type respWriterWithLen struct {
	http.ResponseWriter
	Len *int
}

// Write forwards b to the wrapped writer and accumulates the count it
// reports, even when it also returns an error.
func (w *respWriterWithLen) Write(b []byte) (int, error) {
	written, err := w.ResponseWriter.Write(b)
	*w.Len += written
	return written, err
}
2022-11-15 17:53:34 +03:00
func (server *RPCServer) serveRequest(context context.Context,
logger logr.Logger,
request []byte,
clientVersion string,
clientPlatform int,
clientIP net.IP) (result RPCResult) {
2022-11-15 15:29:42 +03:00
result = RPCResult{}
defer func() {
if rec := recover(); rec != nil {
stackStr := errors.FormatPanicDebugStack(string(debug.Stack()))
// set error that returns after panic handling. See named return values
panicErr := errors.Errorf("[PANIC] %v %s", rec, stackStr)
// set result that returns after panic handling. See named return values
result.Error = panicErr
result.Payload = getResponseBytes(nil, panicErr)
logger.Error(nil, "RPC panic", "PANIC", rec, "client_ver", clientVersion)
}
}()
2022-11-15 17:53:34 +03:00
h := server.newHandler(&RPCCtx{context, logger, clientVersion, clientIP, clientPlatform, server.isDev})
defer h.Close()
2022-11-15 15:29:42 +03:00
p, err := server.executeRPC(h, request)
result.Error = err
result.Payload = getResponseBytes(p, err)
return
}
func (server *RPCServer) executeRPC(h IRPCHandler, request []byte) (meta.IRPC, error) {
p, err := DecodeRequestToRPC(request, server.createRPC)
if err != nil {
return nil, err
}
if err := h.BeforeExecute(p); err != nil {
2022-11-15 17:53:34 +03:00
err = h.AfterExecute(p, err)
2022-11-15 15:29:42 +03:00
//NOTE: returning constructed packet so that it can be inspected
return p, err
}
err = p.Execute(h)
2022-11-15 17:53:34 +03:00
err = h.AfterExecute(p, err)
2022-11-15 15:29:42 +03:00
if err != nil {
return p, err
}
return p, nil
}
// DecodeRequestToRPC decodes a msgpack request. It opens the outer container,
// reads the uint32 RPC code, and instantiates the RPC through createById.
// It then reads the request fields into rpc.GetRequest(). The outer
// container is not closed here. A request-field read error is wrapped with
// the packet code; other errors are returned as-is.
func DecodeRequestToRPC(request []byte, createById meta.RPCFactory) (meta.IRPC, error) {
	reader, err := meta.NewMsgpackReader(request)
	if err != nil {
		return nil, err
	}
	if err = reader.BeginContainer(""); err != nil {
		return nil, err
	}

	var code uint32
	if err = reader.ReadU32(&code, ""); err != nil {
		return nil, err
	}

	rpc, err := createById(code)
	if err != nil {
		return nil, err
	}
	if err = rpc.GetRequest().Read(reader); err != nil {
		return nil, errors.Wrapf(err, "Packet code: %d ", code)
	}
	return rpc, nil
}
// getResponseBytes encodes the response to send back for rpc as one msgpack
// container.
//
// On success (err == nil) it writes the RPC code followed by the RPC's
// response fields; rpc must then be non-nil. On failure (rpc may be nil) it
// writes an error code and description:
//   - Code/Desc of a ServerError found in err's chain, whether it is stored
//     as a pointer or as a value;
//   - otherwise GenericErrorCode and err.Error().
//
// Returns nil if the writer cannot produce its data; the caller treats an
// empty payload as "nothing to write".
func getResponseBytes(rpc meta.IRPC, err error) []byte {
	writer := meta.NewMsgpackWriter()
	writer.BeginContainer("")
	if err == nil {
		writer.WriteI32(rpc.GetCode(), "")
		rpc.GetResponse().Write(writer)
	} else {
		// ServerError has a value-receiver Error method, so both
		// *ServerError and ServerError satisfy error; check both forms.
		var ptrErr *ServerError
		var valErr ServerError
		switch {
		case errors.As(err, &ptrErr):
			writer.WriteI32(int32(ptrErr.Code), "")
			writer.WriteString(ptrErr.Desc, "")
		case errors.As(err, &valErr):
			writer.WriteI32(int32(valErr.Code), "")
			writer.WriteString(valErr.Desc, "")
		default:
			writer.WriteI32(GenericErrorCode, "")
			writer.WriteString(err.Error(), "")
		}
	}
	writer.EndContainer()
	bytes, _ := writer.GetData()
	return bytes
}
// parseVersionHeader splits a version header of the form
// "<client version>;<platform>" into its two parts.
//   - Surrounding whitespace is trimmed from both parts.
//   - Anything after a second ';' is ignored.
//   - The platform is 0 when it is missing, not an integer, or out of range
//     for int.
func parseVersionHeader(header string) (string, int) {
	version, rest := header, ""
	if i := strings.IndexByte(header, ';'); i >= 0 {
		version, rest = header[:i], header[i+1:]
	}
	platformStr := rest
	if i := strings.IndexByte(rest, ';'); i >= 0 {
		platformStr = rest[:i]
	}
	// strconv.Atoi returns ±MaxInt on overflow; treat any parse failure as
	// "unknown platform" (0) instead.
	platform, err := strconv.Atoi(strings.TrimSpace(platformStr))
	if err != nil {
		platform = 0
	}
	return strings.TrimSpace(version), platform
}