commit 80b21c5111d52d51cb44fec14a3e381ffdea012c Author: Pavel Shevaev Date: Tue Nov 15 15:29:42 2022 +0300 first commit diff --git a/server.go b/server.go new file mode 100644 index 0000000..edd824c --- /dev/null +++ b/server.go @@ -0,0 +1,292 @@ +package rpc + +import ( + "compress/gzip" + "context" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "runtime/debug" + "strconv" + "strings" + + "git.bit5.ru/backend/errors" + "git.bit5.ru/backend/meta" + + "github.com/go-logr/logr" + "github.com/google/uuid" +) + +var DevOnlyErr = errors.New("Allowed in dev mode only") + +const GenericErrorCode = -1 + +type ServerError struct { + Code int + Desc string +} + +// implementing error interface +func (e ServerError) Error() string { + return fmt.Sprintf("code: %d, desc: %s", e.Code, e.Desc) +} + +func NewServerError(code int) error { + return &ServerError{Code: code, Desc: ""} +} + +// RPCResult struct to unify response from RPC handlers +type RPCResult struct { + Error error + Payload []byte +} + +type RPCServer struct { + isDev bool + versionHeader string + gzipLenThreshold int + logger logr.Logger + newHandler func(ctx *RPCCtx) IRPCHandler + createRPC func(code uint32)(meta.IRPC, error) +} + +type IRPCHandler interface { + BeforeExecute(p meta.IRPC) error + AfterExecute(p meta.IRPC, err error) error + Close() +} + +type RPCCtx struct { + Ctx context.Context + Log logr.Logger + ClientVersion string + ClientIP net.IP + ClientPlatform int + IsDev bool +} + +func NewRPCServer(logger logr.Logger, + isDev bool, + versionHeader string, + gzipLenThreshold int, + newHandler func(ctx *RPCCtx) IRPCHandler, + createRPC func(code uint32)(meta.IRPC, error)) *RPCServer { + return &RPCServer{ + isDev: isDev, + versionHeader: versionHeader, + gzipLenThreshold : gzipLenThreshold, + logger: logger.WithName("[rpc]"), + newHandler: newHandler, + createRPC: createRPC} +} + +func (server *RPCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/text; charset=utf-8") + + canGzip := strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") + + versionHeader := r.Header.Get(server.versionHeader) + cltVersion, cltPlatform := parseVersionHeader(versionHeader) + cltIP := GetIP(r) + server.logger.V(1).Info("New request", "client_ver", cltVersion) + + payload, err := ioutil.ReadAll(r.Body) + if err != nil { + server.logger.Error(err, "Error reading body", "client_ver", cltVersion) + return + } + + logger := server.logger + if server.isDev { + reqId := uuid.New() + logger = logger.WithValues("uid", reqId.String()) + } + + result := server.serveRequest(r.Context(), logger, payload, cltVersion, cltPlatform, cltIP) + + server.writeResponse(canGzip, w, result.Payload) +} + +func GetIP(r *http.Request) net.IP { + ipStr := GetIPStr(r) + return net.ParseIP(ipStr) +} + +func GetIPStr(r *http.Request) string { + if ip := r.Header.Get("X-Real-IP"); ip != "" { + return ip + } + + if ip := r.Header.Get("X-Forwarded-For"); ip != "" { + return strings.Split(ip, ", ")[0] + } + + ra, _, _ := net.SplitHostPort(r.RemoteAddr) + return ra +} + + +func (server *RPCServer) writeResponse(canGzip bool, w http.ResponseWriter, payload []byte) { + if len(payload) == 0 { + server.logger.Error(nil, "Nothing was written to response!") + return + } + var wr io.Writer = w + wrLen := 0 + if server.isDev { + wr = &respWriterWithLen{w, &wrLen} + } + + if canGzip && len(payload) > server.gzipLenThreshold { + w.Header().Set("Content-Encoding", "gzip") + server.logger.V(1).Info("Response is gzipped") + gz := gzip.NewWriter(wr) + defer gz.Close() + gz.Write(payload) + } else { + wr.Write(payload) + } + + if server.isDev { + server.logger.V(1).Info("Response size", "len", wrLen, "orig", len(payload)) + } +} + +type respWriterWithLen struct { + http.ResponseWriter + Len *int +} + +func (w *respWriterWithLen) Write(b []byte) (n int, err error) { + n, err = w.ResponseWriter.Write(b) + (*w.Len) += n + return +} + +func (server *RPCServer) serveRequest(context context.Context, + logger logr.Logger, + request []byte, + clientVersion string, + clientPlatform int, + clientIP net.IP) (result RPCResult) { + + result = RPCResult{} + + defer func() { + if rec := recover(); rec != nil { + stackStr := errors.FormatPanicDebugStack(string(debug.Stack())) + // set error that returns after panic handling. See named return values + panicErr := errors.Errorf("[PANIC] %v %s", rec, stackStr) + + // set result that returns after panic handling. See named return values + result.Error = panicErr + result.Payload = getResponseBytes(nil, panicErr) + + logger.Error(nil, "RPC panic", "PANIC", rec, "client_ver", clientVersion) + } + }() + + h := server.newHandler(&RPCCtx{context, logger, clientVersion, clientIP, clientPlatform, server.isDev}) + defer h.Close() + + p, err := server.executeRPC(h, request) + + result.Error = err + result.Payload = getResponseBytes(p, err) + + return +} + +func (server *RPCServer) executeRPC(h IRPCHandler, request []byte) (meta.IRPC, error) { + p, err := DecodeRequestToRPC(request, server.createRPC) + if err != nil { + return nil, err + } + + if err := h.BeforeExecute(p); err != nil { + err = h.AfterExecute(p, err) + //NOTE: returning constructed packet so that it can be inspected + return p, err + } + err = p.Execute(h) + err = h.AfterExecute(p, err) + + if err != nil { + return p, err + } + + return p, nil +} + +func DecodeRequestToRPC(request []byte, createById meta.RPCFactory) (meta.IRPC, error) { + reader, err := meta.NewMsgpackReader(request) + if err != nil { + return nil, err + } + var code uint32 + err = reader.BeginContainer("") + if err != nil { + return nil, err + } + err = reader.ReadU32(&code, "") + if err != nil { + return nil, err + } + + rpc, err := createById(code) + if err != nil { + return nil, err + } + + req := rpc.GetRequest() + err = req.Read(reader) + if err != nil { + err = errors.Wrapf(err, "Packet code: %d ", code) + return nil, err + } + return rpc, nil +} + +func getResponseBytes(rpc meta.IRPC, err error) []byte { + writer := meta.NewMsgpackWriter() + writer.BeginContainer("") + + if err == nil { + writer.WriteI32(rpc.GetCode(), "") + rpc.GetResponse().Write(writer) + } else { + var rpcErr *ServerError + if errors.As(err, &rpcErr) { + writer.WriteI32(int32(rpcErr.Code), "") + writer.WriteString(rpcErr.Desc, "") + } else { + writer.WriteI32(GenericErrorCode, "") + writer.WriteString(err.Error(), "") + } + } + + writer.EndContainer() + + bytes, _ := writer.GetData() + return bytes +} + +func parseVersionHeader(header string) (string, int) { + parts := strings.Split(header, ";") + cltVersion := "" + cltPlatformStr := "" + cltPlatform := 0 + if len(parts) > 0 { + cltVersion = parts[0] + if len(parts) > 1 { + cltPlatformStr = parts[1] + } + } + + cltPlatform, _ = strconv.Atoi(cltPlatformStr) + + return cltVersion, cltPlatform +} +