package rpc import ( "compress/gzip" "context" "fmt" "io" "io/ioutil" "net" "net/http" "runtime/debug" "strconv" "strings" "git.bit5.ru/backend/errors" "git.bit5.ru/backend/meta" "github.com/go-logr/logr" "github.com/google/uuid" ) var DevOnlyErr = errors.New("Allowed in dev mode only") const GenericErrorCode = -1 type ServerError struct { Code int Desc string } // implementing error interface func (e ServerError) Error() string { return fmt.Sprintf("code: %d, desc: %s", e.Code, e.Desc) } func NewServerError(code int) error { return &ServerError{Code: code, Desc: ""} } // RPCResult struct to unify response from RPC handlers type RPCResult struct { Error error Payload []byte } type RPCServer struct { isDev bool versionHeader string gzipLenThreshold int logger logr.Logger newHandler func(ctx *RPCCtx) IRPCHandler createRPC func(code uint32) (meta.IRPC, error) } type IRPCHandler interface { BeforeExecute(p meta.IRPC) error AfterExecute(p meta.IRPC, err error) error Close() } type RPCCtx struct { Ctx context.Context Log logr.Logger ClientVersion string ClientIP net.IP ClientPlatform int IsDev bool } func NewRPCServer(logger logr.Logger, isDev bool, versionHeader string, gzipLenThreshold int, newHandler func(ctx *RPCCtx) IRPCHandler, createRPC func(code uint32) (meta.IRPC, error)) *RPCServer { return &RPCServer{ isDev: isDev, versionHeader: versionHeader, gzipLenThreshold: gzipLenThreshold, logger: logger.WithName("[rpc]"), newHandler: newHandler, createRPC: createRPC} } func (server *RPCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/text; charset=utf-8") canGzip := strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") versionHeader := r.Header.Get(server.versionHeader) cltVersion, cltPlatform := parseVersionHeader(versionHeader) cltIP := GetIP(r) server.logger.V(1).Info("New request", "client_ver", cltVersion) payload, err := ioutil.ReadAll(r.Body) if err != nil { payloadLen := len(payload) server.logger.Error(err, "Error reading body", "client_ver", cltVersion, "client_platform", cltPlatform, "client_ip", cltIP, "body_len", payloadLen) return } logger := server.logger if server.isDev { reqId := uuid.New() logger = logger.WithValues("uid", reqId.String()) } result := server.serveRequest(r.Context(), logger, payload, cltVersion, cltPlatform, cltIP) server.writeResponse(canGzip, w, result.Payload) } func GetIP(r *http.Request) net.IP { ipStr := GetIPStr(r) return net.ParseIP(ipStr) } func GetIPStr(r *http.Request) string { if ip := r.Header.Get("X-Real-IP"); ip != "" { return ip } if ip := r.Header.Get("X-Forwarded-For"); ip != "" { return strings.Split(ip, ", ")[0] } ra, _, _ := net.SplitHostPort(r.RemoteAddr) return ra } func (server *RPCServer) writeResponse(canGzip bool, w http.ResponseWriter, payload []byte) { if len(payload) == 0 { server.logger.Error(nil, "Nothing was written to response!") return } var wr io.Writer = w wrLen := 0 if server.isDev { wr = &respWriterWithLen{w, &wrLen} } if canGzip && len(payload) > server.gzipLenThreshold { w.Header().Set("Content-Encoding", "gzip") server.logger.V(1).Info("Response is gzipped") gz := gzip.NewWriter(wr) defer gz.Close() gz.Write(payload) } else { wr.Write(payload) } if server.isDev { server.logger.V(1).Info("Response size", "len", wrLen, "orig", len(payload)) } } type respWriterWithLen struct { http.ResponseWriter Len *int } func (w *respWriterWithLen) Write(b []byte) (n int, err error) { n, err = w.ResponseWriter.Write(b) (*w.Len) += n return } func (server *RPCServer) serveRequest(context context.Context, logger logr.Logger, request []byte, clientVersion string, clientPlatform int, clientIP net.IP) (result RPCResult) { result = RPCResult{} defer func() { if rec := recover(); rec != nil { stackStr := errors.FormatPanicDebugStack(string(debug.Stack())) // set error that returns after panic handling. See named return values panicErr := errors.Errorf("[PANIC] %v %s", rec, stackStr) // set result that returns after panic handling. See named return values result.Error = panicErr result.Payload = getResponseBytes(nil, panicErr) logger.Error(nil, "RPC panic", "PANIC", rec, "client_ver", clientVersion) } }() h := server.newHandler(&RPCCtx{context, logger, clientVersion, clientIP, clientPlatform, server.isDev}) defer h.Close() p, err := server.executeRPC(h, request) result.Error = err result.Payload = getResponseBytes(p, err) return } func (server *RPCServer) executeRPC(h IRPCHandler, request []byte) (meta.IRPC, error) { p, err := DecodeRequestToRPC(request, server.createRPC) if err != nil { return nil, err } if err := h.BeforeExecute(p); err != nil { err = h.AfterExecute(p, err) //NOTE: returning constructed packet so that it can be inspected return p, err } err = p.Execute(h) err = h.AfterExecute(p, err) if err != nil { return p, err } return p, nil } func DecodeRequestToRPC(request []byte, createById meta.RPCFactory) (meta.IRPC, error) { reader, err := meta.NewMsgpackReader(request) if err != nil { return nil, err } var code uint32 err = reader.BeginContainer("") if err != nil { return nil, err } err = reader.ReadU32(&code, "") if err != nil { return nil, err } rpc, err := createById(code) if err != nil { return nil, err } req := rpc.GetRequest() err = req.Read(reader) if err != nil { err = errors.Wrapf(err, "Packet code: %d ", code) return nil, err } return rpc, nil } func getResponseBytes(rpc meta.IRPC, err error) []byte { writer := meta.NewMsgpackWriter() writer.BeginContainer("") if err == nil { writer.WriteI32(rpc.GetCode(), "") rpc.GetResponse().Write(writer) } else { var rpcErr *ServerError if errors.As(err, &rpcErr) { writer.WriteI32(int32(rpcErr.Code), "") writer.WriteString(rpcErr.Desc, "") } else { writer.WriteI32(GenericErrorCode, "") writer.WriteString(err.Error(), "") } } writer.EndContainer() bytes, _ := writer.GetData() return bytes } func parseVersionHeader(header string) (string, int) { parts := strings.Split(header, ";") cltVersion := "" cltPlatformStr := "" cltPlatform := 0 if len(parts) > 0 { cltVersion = parts[0] if len(parts) > 1 { cltPlatformStr = parts[1] } } cltPlatform, _ = strconv.Atoi(cltPlatformStr) return cltVersion, cltPlatform }