tableflip/upgrader.go

308 lines
6.9 KiB
Go

package tableflip
import (
"context"
"io/ioutil"
"os"
"path/filepath"
"strconv"
"sync"
"time"
"github.com/pkg/errors"
)
// DefaultUpgradeTimeout is the duration before the Upgrader kills the new process if no
// readiness notification was received.
//
// It is applied whenever Options.UpgradeTimeout is zero or negative.
const DefaultUpgradeTimeout time.Duration = time.Minute
// Options control the behaviour of the Upgrader.
type Options struct {
// Time after which an upgrade is considered failed. Defaults to
// DefaultUpgradeTimeout when zero or negative.
UpgradeTimeout time.Duration
// The PID of a ready process is written to this file. No PID file is
// written when this is empty.
PIDFile string
// WDir is forwarded to startChild as its first argument whenever an
// upgrade spawns a new process.
// NOTE(review): presumably the working directory of the child; confirm
// against startChild.
WDir string
// BinFile is forwarded to startChild as its second argument whenever an
// upgrade spawns a new process.
// NOTE(review): presumably the path of the binary to execute; confirm
// against startChild.
BinFile string
}
// Upgrader handles zero downtime upgrades and passing files between processes.
type Upgrader struct {
*Fds
*env
// opts holds the options given at construction, with UpgradeTimeout
// already defaulted.
opts Options
// parent is nil when this process has no parent (see HasParent).
parent *parent
// parentErr has a buffer of one. WaitForParent stores the parent's result
// here so that later calls can read it back.
parentErr chan error
// readyOnce ensures that Ready closes readyC only once.
readyOnce sync.Once
// readyC is closed by the first call to Ready.
readyC chan struct{}
// stopOnce ensures that Stop closes stopC only once.
stopOnce sync.Once
// stopC is closed by Stop; it aborts a running upgrade and ends run.
stopC chan struct{}
// upgradeC carries the response channel of each Upgrade call to run.
upgradeC chan chan<- error
// exitC is closed when run returns; it is exposed through Exit.
exitC chan struct{}
// exitFd keeps the file obtained from a successful upgrade referenced so
// that it is not garbage collected.
exitFd chan neverCloseThisFile
}
var (
// stdEnvMu serialises calls to New and guards stdEnvUpgrader.
stdEnvMu sync.Mutex
// stdEnvUpgrader is the Upgrader created by New. It is nil until New has
// succeeded, and keeping it here prevents it from being garbage collected.
stdEnvUpgrader *Upgrader
)
// New creates the process-wide Upgrader. Files are passed from the parent
// and may be empty.
//
// Once a call has succeeded, every later call returns an error. A call that
// fails leaves no Upgrader registered.
func New(opts Options) (upg *Upgrader, err error) {
	stdEnvMu.Lock()
	defer stdEnvMu.Unlock()

	if stdEnvUpgrader == nil {
		upg, err = newUpgrader(stdEnv, opts)
		// Keep upg reachable through a private global so that it is never
		// GC'ed, which would close exitFd prematurely.
		stdEnvUpgrader = upg
		return upg, err
	}
	return nil, errors.New("tableflip: only a single Upgrader allowed")
}
// newUpgrader builds an Upgrader for the given environment and starts its
// run loop in a background goroutine.
//
// It fails when the initial working directory is unknown or when newParent
// reports an error. A zero or negative opts.UpgradeTimeout is replaced by
// DefaultUpgradeTimeout.
func newUpgrader(env *env, opts Options) (*Upgrader, error) {
	if initialWD == "" {
		return nil, errors.New("couldn't determine initial working directory")
	}

	par, inherited, err := newParent(env)
	if err != nil {
		return nil, err
	}

	if opts.UpgradeTimeout <= 0 {
		opts.UpgradeTimeout = DefaultUpgradeTimeout
	}

	upg := new(Upgrader)
	upg.env = env
	upg.opts = opts
	upg.parent = par
	// Buffered so that WaitForParent can park the parent's result without
	// needing a concurrent receiver.
	upg.parentErr = make(chan error, 1)
	upg.readyC = make(chan struct{})
	upg.stopC = make(chan struct{})
	upg.upgradeC = make(chan chan<- error)
	upg.exitC = make(chan struct{})
	// Buffered so that run can store the file and return immediately.
	upg.exitFd = make(chan neverCloseThisFile, 1)
	upg.Fds = newFds(inherited)

	go upg.run(opts.WDir, opts.BinFile)

	return upg, nil
}
// Ready signals that the current process is ready to accept connections.
// It must be called to finish the upgrade.
//
// The first call closes all inherited fds that were not used. Every call
// writes the PID file (if one is configured) and, when there is a parent,
// notifies it.
func (u *Upgrader) Ready() error {
	markReady := func() {
		u.Fds.closeInherited()
		close(u.readyC)
	}
	u.readyOnce.Do(markReady)

	if pidFile := u.opts.PIDFile; pidFile != "" {
		err := writePIDFile(pidFile)
		if err != nil {
			return errors.Wrap(err, "tableflip: can't write PID file")
		}
	}

	if u.parent != nil {
		return u.parent.sendReady()
	}
	return nil
}
// Exit returns a receive-only channel which is closed once the process
// should exit.
func (u *Upgrader) Exit() <-chan struct{} {
	var exit <-chan struct{} = u.exitC
	return exit
}
// Stop prevents any more upgrades from happening, and closes
// the exit channel. Calling it more than once is safe.
//
// If this function is called before a call to Upgrade() has
// succeeded, it is assumed that the process is being shut down
// completely. All Unix sockets known to Upgrader.Fds are then
// unlinked from the filesystem.
func (u *Upgrader) Stop() {
	// Closing stopC interrupts any running Upgrade() and keeps new
	// upgrades from starting.
	signalStop := func() { close(u.stopC) }
	u.stopOnce.Do(signalStop)
}
// WaitForParent blocks until the parent has exited or ctx is done.
//
// It returns nil right away when there is no parent, ctx.Err() when ctx is
// done first, and otherwise the parent's result, which is non-nil if the
// parent misbehaved during shutdown.
func (u *Upgrader) WaitForParent(ctx context.Context) error {
	if u.parent == nil {
		return nil
	}

	var result error
	select {
	case <-ctx.Done():
		return ctx.Err()
	case result = <-u.parentErr:
	case result = <-u.parent.result:
	}

	// Put the result back for the next caller. This is a bit cheeky,
	// since it makes multiple calls to WaitForParent resolve in
	// sequence, but that probably doesn't matter.
	u.parentErr <- result
	return result
}
// HasParent reports whether the current process was started by an upgrade
// (true) or is the first invocation (false).
func (u *Upgrader) HasParent() bool {
	if u.parent == nil {
		return false
	}
	return true
}
// Upgrade triggers an upgrade and blocks until the run loop reports the
// outcome.
//
// It fails immediately if Stop has been called or the run loop has already
// finished.
func (u *Upgrader) Upgrade() error {
	// Buffered so the run loop never blocks while delivering the outcome.
	reply := make(chan error, 1)

	select {
	case u.upgradeC <- reply:
		return <-reply
	case <-u.exitC:
		return errors.New("already upgraded")
	case <-u.stopC:
		return errors.New("terminating")
	}
}
// errNotReady is sent back to Upgrade callers that request an upgrade
// before Ready has been called on this process.
var errNotReady = errors.New("process is not ready yet")
// run is the Upgrader's event loop. It returns, closing exitC, once Stop has
// been called or an upgrade has succeeded.
//
// Upgrade requests are rejected until this process is ready and, if it has a
// parent, until that parent has exited. wdir and binpath are passed through
// to doUpgrade unchanged.
func (u *Upgrader) run(wdir string, binpath string) {
	defer close(u.exitC)

	// Receiving from a nil channel blocks forever, so each of these doubles
	// as a flag: non-nil means the event is still outstanding.
	processReady := u.readyC
	var parentExited <-chan struct{}
	if u.parent != nil {
		parentExited = u.parent.exited
	}

	for {
		select {
		case <-u.stopC:
			u.Fds.closeAndRemoveUsed()
			return

		case <-processReady:
			processReady = nil

		case <-parentExited:
			parentExited = nil

		case reply := <-u.upgradeC:
			switch {
			case processReady != nil:
				reply <- errNotReady

			case parentExited != nil:
				reply <- errors.New("parent hasn't exited")

			default:
				file, err := u.doUpgrade(wdir, binpath)
				reply <- err
				if err != nil {
					// A failed upgrade leaves this process in charge.
					continue
				}

				// Park file in exitFd so that it is only closed when the
				// process exits. That close tells the new process that
				// the old one is gone.
				u.exitFd <- neverCloseThisFile{file}
				u.Fds.closeUsed()
				return
			}
		}
	}
}
// doUpgrade starts a child process via startChild and waits for the outcome.
//
// On success it returns the file received on child.ready. It returns an
// error if the child cannot be started, if the child exits before signalling
// readiness, if opts.UpgradeTimeout elapses, or if Stop is called while
// waiting. In the last two cases the child is killed. Upgrade requests that
// arrive while waiting are answered with "upgrade in progress".
func (u *Upgrader) doUpgrade(wdir string, binpath string) (*os.File, error) {
	child, err := startChild(wdir, binpath, u.env, u.Fds.copy())
	if err != nil {
		return nil, errors.Wrap(err, "can't start child")
	}

	// Use an explicit timer rather than time.After so that it is stopped as
	// soon as we return, instead of lingering until UpgradeTimeout elapses.
	readyTimeout := time.NewTimer(u.opts.UpgradeTimeout)
	defer readyTimeout.Stop()

	for {
		select {
		case request := <-u.upgradeC:
			request <- errors.New("upgrade in progress")

		case err := <-child.result:
			// Any exit before the ready notification is a failure, even a
			// clean one.
			if err == nil {
				return nil, errors.Errorf("child %s exited", child)
			}
			return nil, errors.Wrapf(err, "child %s exited", child)

		case <-u.stopC:
			child.Kill()
			return nil, errors.New("terminating")

		case <-readyTimeout.C:
			child.Kill()
			return nil, errors.Errorf("new child %s timed out", child)

		case file := <-child.ready:
			return file, nil
		}
	}
}
// neverCloseThisFile wraps a file that must never be closed by the Go
// runtime, since it is used by the child to determine when the parent has
// died. It must only be closed by the OS.
// Hence we make sure that this file can't be garbage collected by referencing
// it from an Upgrader.
type neverCloseThisFile struct {
file *os.File
}
// writePIDFile writes the PID of the current process to path.
//
// The PID is first written to a temporary file created next to the target,
// which is then renamed to path. If path has no directory component, the
// temporary file is created in initialWD. The temporary file is removed
// again if any step fails.
func writePIDFile(path string) (err error) {
	dir, file := filepath.Split(path)

	// if dir is empty, the user probably specified just the name
	// of the pid file expecting it to be created in the current work directory
	if dir == "" {
		dir = initialWD
	}

	if dir == "" {
		return errors.New("empty initial working directory")
	}

	fh, err := ioutil.TempFile(dir, file)
	if err != nil {
		return err
	}
	tmpName := fh.Name()

	// Remove the temporary PID file if something fails. After a successful
	// rename the temporary name no longer exists, so there is nothing to do.
	defer func() {
		if err != nil {
			os.Remove(tmpName)
		}
	}()

	if _, err = fh.WriteString(strconv.Itoa(os.Getpid())); err != nil {
		// The write error is the one worth reporting; Close is best effort.
		fh.Close()
		return err
	}

	// Close before renaming so that a failure reported by Close is seen
	// before the file becomes visible under its final name.
	if err = fh.Close(); err != nil {
		return err
	}

	return os.Rename(tmpName, path)
}