tableflip/upgrader_test.go

498 lines
8.7 KiB
Go

package tableflip
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"io/ioutil"
"os"
"strconv"
"testing"
"time"
)
// testUpgrader pairs an Upgrader built on the fake environment from testEnv
// with the channel on which that environment reports started processes.
type testUpgrader struct {
	// Embedded so tests can call Upgrader methods directly on the wrapper.
	*Upgrader

	// procs yields each process started through the fake environment,
	// e.g. the child spawned by Upgrade.
	procs chan *testProcess
}
// newTestUpgrader builds an Upgrader with opts on top of the fake environment
// returned by testEnv and calls Ready on it. Setup failures panic because
// this helper has no *testing.T to report through.
func newTestUpgrader(opts Options) *testUpgrader {
	env, procs := testEnv()

	upg, err := newUpgrader(env, opts)
	if err != nil {
		panic(err)
	}

	if err = upg.Ready(); err != nil {
		panic(err)
	}

	return &testUpgrader{Upgrader: upg, procs: procs}
}
// upgradeProc runs Upgrade in a background goroutine and returns the child
// process it spawned, plus a channel that will receive Upgrade's final
// result. Calls that fail with errNotReady are retried. If Upgrade finishes
// before any process shows up on tu.procs, the test fails immediately.
func (tu *testUpgrader) upgradeProc(t *testing.T) (*testProcess, <-chan error) {
	t.Helper()

	// Buffered so the goroutine can deliver its result even if the caller
	// never reads it.
	result := make(chan error, 1)
	go func() {
		err := tu.Upgrade()
		for err == errNotReady {
			err = tu.Upgrade()
		}
		result <- err
	}()

	select {
	case proc := <-tu.procs:
		return proc, result
	case err := <-result:
		t.Fatal("Upgrade failed:", err)
		return nil, nil
	}
}
// names lists the inherited files used by TestUpgraderOnOS: the parent
// registers a pipe under each name and the child writes the name into it.
var names = []string{
	"zaphod",
	"beeblebrox",
}
// TestMain doubles as the entry point for child processes spawned by an
// upgrade. Without a parent upgrader it simply runs the test suite.
// Otherwise it reports back through the inherited files, when present:
//   - "pid": its process ID as 8 little-endian bytes,
//   - "hasParent": the result of HasParent formatted with fmt.Sprint,
//   - each entry of names: the name itself.
// It then calls Ready and returns without running any tests.
func TestMain(m *testing.M) {
	upg, err := New(Options{})
	if err != nil {
		panic(err)
	}

	if upg.parent == nil {
		// Execute test suite if there is no parent.
		os.Exit(m.Run())
	}

	pid, err := upg.Fds.File("pid")
	if err != nil {
		panic(err)
	}
	if pid != nil {
		buf := make([]byte, 8)
		binary.LittleEndian.PutUint64(buf, uint64(os.Getpid()))
		// Fail loudly here: a dropped write would only surface in the
		// parent as a confusing short read.
		if _, err := pid.Write(buf); err != nil {
			panic(err)
		}
		pid.Close()
	}

	parent, err := upg.Fds.File("hasParent")
	if err != nil {
		panic(err)
	}
	if parent != nil {
		if _, err := io.WriteString(parent, fmt.Sprint(upg.HasParent())); err != nil {
			panic(err)
		}
		parent.Close()
	}

	for _, name := range names {
		file, err := upg.Fds.File(name)
		if err != nil {
			panic(err)
		}
		if file == nil {
			continue
		}
		if _, err := io.WriteString(file, name); err != nil {
			panic(err)
		}
		// Release this handle as soon as it has been written, like the
		// pid and hasParent files above.
		file.Close()
	}

	if err := upg.Ready(); err != nil {
		panic(err)
	}
}
// TestUpgraderOnOS performs an upgrade against the real environment
// (stdEnv). The spawned child is expected to run this binary's TestMain in
// child mode and to report back through pipes registered as inherited
// files: its PID on "pid", whether it saw a parent on "hasParent", and each
// entry of names written into the pipe of that name.
func TestUpgraderOnOS(t *testing.T) {
	u, err := newUpgrader(stdEnv, Options{})
	if err != nil {
		t.Fatal("Can't create Upgrader:", err)
	}
	defer u.Stop()

	rPid, wPid, err := os.Pipe()
	if err != nil {
		t.Fatal(err)
	}
	defer rPid.Close()
	if err := u.Fds.AddFile("pid", wPid); err != nil {
		t.Fatal(err)
	}
	wPid.Close()

	rHasParent, wHasParent, err := os.Pipe()
	if err != nil {
		t.Fatal(err)
	}
	defer rHasParent.Close()
	if err := u.Fds.AddFile("hasParent", wHasParent); err != nil {
		t.Fatal(err)
	}
	wHasParent.Close()

	var readers []*os.File
	defer func() {
		for _, r := range readers {
			r.Close()
		}
	}()
	for _, name := range names {
		r, w, err := os.Pipe()
		if err != nil {
			t.Fatal(err)
		}
		readers = append(readers, r)
		if err := u.Fds.AddFile(name, w); err != nil {
			t.Fatal(err)
		}
		w.Close()
	}

	if err := u.Upgrade(); err == nil {
		t.Error("Upgrade before Ready should return an error")
	}
	if err := u.Ready(); err != nil {
		t.Fatal("Ready failed:", err)
	}
	// Retry while Upgrade reports errNotReady; any other error is fatal.
	for {
		if err := u.Upgrade(); err == nil {
			break
		} else if err != errNotReady {
			t.Fatal("Upgrade failed:", err)
		}
	}

	// Close copies of write pipes, so that
	// reads below return EOF.
	u.Stop()

	// A single Read on a pipe may legally return fewer than 8 bytes.
	// ReadFull keeps reading until the whole PID has arrived, or reports
	// io.ErrUnexpectedEOF on a short write.
	buf := make([]byte, 8)
	if _, err := io.ReadFull(rPid, buf); err != nil {
		t.Fatal(err)
	}
	if int(binary.LittleEndian.Uint64(buf)) == os.Getpid() {
		t.Error("Child did not execute in new process")
	}

	hasParentBytes, err := ioutil.ReadAll(rHasParent)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(hasParentBytes, []byte("true")) {
		t.Fatal("Child did not recognize parent")
	}

	for i, name := range names {
		nameBytes, err := ioutil.ReadAll(readers[i])
		if err != nil {
			t.Fatal(err)
		}
		if !bytes.Equal(nameBytes, []byte(name)) {
			t.Fatalf("File %s has name %s in child", name, string(nameBytes))
		}
	}
}
// TestUpgraderCleanExit checks that Upgrade fails when the new child exits
// with a nil error without ever notifying the upgrader.
func TestUpgraderCleanExit(t *testing.T) {
	t.Parallel()

	upg := newTestUpgrader(Options{})
	defer upg.Stop()

	child, result := upg.upgradeProc(t)
	child.exit(nil)

	err := <-result
	if err == nil {
		t.Error("Expected Upgrade to return error when new child exits clean")
	}
}
// TestUpgraderUncleanExit checks that Upgrade fails when the new child exits
// with a non-nil error without ever notifying the upgrader.
func TestUpgraderUncleanExit(t *testing.T) {
	t.Parallel()

	upg := newTestUpgrader(Options{})
	defer upg.Stop()

	child, result := upg.upgradeProc(t)
	exitErr := errors.New("some error")
	child.exit(exitErr)

	err := <-result
	if err == nil {
		t.Error("Expected Upgrade to return error when new child exits unclean")
	}
}
// TestUpgraderTimeout checks that a child which never reports readiness
// receives os.Kill once UpgradeTimeout elapses, and that Upgrade then fails.
func TestUpgraderTimeout(t *testing.T) {
	t.Parallel()

	upg := newTestUpgrader(Options{UpgradeTimeout: 10 * time.Millisecond})
	defer upg.Stop()

	child, result := upg.upgradeProc(t)

	sig := child.recvSignal(nil)
	if sig != os.Kill {
		t.Error("Expected os.Kill, got", sig)
	}

	err := <-result
	if err == nil {
		t.Error("Expected Upgrade to return error when new child times out")
	}
}
// TestUpgraderConcurrentUpgrade checks that a second Upgrade is refused
// while a first one is still waiting on its child.
func TestUpgraderConcurrentUpgrade(t *testing.T) {
	t.Parallel()

	upg := newTestUpgrader(Options{})
	defer upg.Stop()

	child, _ := upg.upgradeProc(t)
	// Consume a signal sent to the pending child in the background.
	// NOTE(review): presumably so that signalling during cleanup never blocks.
	go child.recvSignal(nil)

	if err := upg.Upgrade(); err == nil {
		t.Error("Expected Upgrade to refuse concurrent upgrade")
	}

	child.exit(nil)
}
// TestHasParent checks that an upgrader created by newTestUpgrader
// reports no parent.
func TestHasParent(t *testing.T) {
	t.Parallel()

	upg := newTestUpgrader(Options{})
	defer upg.Stop()

	if hasParent := upg.HasParent(); hasParent {
		t.Fatal("First process cannot have a parent")
	}
}
// TestUpgraderWaitForParent starts a child in the fake environment and
// checks that the child's WaitForParent is still blocked after one second,
// then returns nil once the file received on child.ready is closed.
func TestUpgraderWaitForParent(t *testing.T) {
	t.Parallel()

	env, procs := testEnv()
	child, err := startChild(env, nil)
	if err != nil {
		t.Fatal(err)
	}

	proc := <-procs
	upg, err := newUpgrader(&proc.env, Options{})
	if err != nil {
		t.Fatal(err)
	}
	defer upg.Stop()

	if err := upg.Ready(); err != nil {
		t.Fatal(err)
	}

	waitErr := make(chan error, 1)
	go func() { waitErr <- upg.WaitForParent(context.Background()) }()

	// WaitForParent must not return while the ready file is still open.
	select {
	case <-time.After(time.Second):
	case <-waitErr:
		t.Fatal("Returned before parent exited")
	}

	readyFile := <-child.ready
	if err := readyFile.Close(); err != nil {
		t.Fatal(err)
	}

	if err := <-waitErr; err != nil {
		t.Fatal("Unexpected error:", err)
	}
}
// TestUpgraderReady checks the successful path:
//   - once the child notifies readiness, Upgrade returns nil and Exit()
//     is immediately receivable;
//   - closing the file taken from exitFd is observed by the child within
//     one second.
func TestUpgraderReady(t *testing.T) {
	t.Parallel()

	upg := newTestUpgrader(Options{})
	defer upg.Stop()

	child, result := upg.upgradeProc(t)

	_, childExited, err := child.notify()
	if err != nil {
		t.Fatal("Can't notify Upgrader:", err)
	}

	if err := <-result; err != nil {
		t.Fatal("Expected Upgrade to return nil when child is ready")
	}

	// Non-blocking receive: Exit() must already be ready at this point.
	select {
	case <-upg.Exit():
	default:
		t.Error("Expected Exit() to be closed when upgrade is done")
	}

	// Simulate the process exiting
	exitFile := <-upg.exitFd
	exitFile.file.Close()

	select {
	case <-time.After(time.Second):
		t.Error("Child wasn't notified of parent exiting")
	case err := <-childExited:
		if err != nil {
			t.Error("exit error", err)
		}
	}
}
// TestUpgraderShutdownCancelsUpgrade checks that Stop makes an in-flight
// Upgrade return an error, and that Upgrade calls after Stop are refused.
func TestUpgraderShutdownCancelsUpgrade(t *testing.T) {
	t.Parallel()

	u := newTestUpgrader(Options{})
	defer u.Stop()

	// Named child rather than new, to avoid shadowing the builtin.
	child, errs := u.upgradeProc(t)
	go child.recvSignal(nil)

	u.Stop()
	if err := <-errs; err == nil {
		t.Error("Upgrade doesn't return an error when Stop()ped")
	}

	if err := u.Upgrade(); err == nil {
		t.Error("Upgrade doesn't return an error after Stop()")
	}
}
// TestReadyWritesPIDFile checks that, with Options.PIDFile set, a repeated
// Ready call succeeds and the file holds the current process ID.
func TestReadyWritesPIDFile(t *testing.T) {
	t.Parallel()

	dir, err := ioutil.TempDir("", "tableflip")
	if err != nil {
		t.Fatal(err)
	}
	defer os.RemoveAll(dir)

	pidPath := dir + "/pid"
	upg := newTestUpgrader(Options{PIDFile: pidPath})
	defer upg.Stop()

	// newTestUpgrader already called Ready once; a second call must succeed too.
	if err := upg.Ready(); err != nil {
		t.Fatal("Ready returned error:", err)
	}

	pidFile, err := os.Open(pidPath)
	if err != nil {
		t.Fatal("PID file doesn't exist:", err)
	}
	defer pidFile.Close()

	var got int
	if _, err := fmt.Fscan(pidFile, &got); err != nil {
		t.Fatal("Can't read PID:", err)
	}

	if want := os.Getpid(); got != want {
		t.Error("PID doesn't match")
	}
}
// TestWritePidFileWithoutPath checks that writePIDFile accepts a bare file
// name with no directory component. The resulting file must be readable
// from the working directory and contain the current PID.
func TestWritePidFileWithoutPath(t *testing.T) {
	const pidFile = "tableflip-test.pid"

	if err := writePIDFile(pidFile); err != nil {
		t.Fatal("Could not write pidfile:", err)
	}
	defer os.Remove(pidFile)

	// Read the file back to confirm it was actually created.
	fh, err := os.Open(pidFile)
	if err != nil {
		t.Fatal("PID file doesn't exist:", err)
	}
	defer fh.Close()

	// Check the PID itself: a leftover file from an earlier failed run
	// could otherwise be mistaken for a freshly written one.
	var pid int
	if _, err := fmt.Fscan(fh, &pid); err != nil {
		t.Fatal("Can't read PID:", err)
	}

	if os.Getpid() != pid {
		t.Error("PID doesn't match")
	}
}
// BenchmarkUpgrade measures a full upgrade against the real environment
// (stdEnv). For each size n, the upgrader carries n inherited descriptors:
// n/2 pipes, with both ends registered.
func BenchmarkUpgrade(b *testing.B) {
	for _, n := range []int{4, 400, 4000} {
		b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) {
			fds := newFds(nil)
			for i := 0; i < n; i += 2 {
				r, w, err := os.Pipe()
				if err != nil {
					b.Fatal(err)
				}

				// Give every descriptor its own name (i and i+1) so that n
				// entries are registered. Using the same name for all of them
				// would not exercise n descriptors.
				// NOTE(review): assumes Fds keys entries by name.
				if err := fds.AddFile(strconv.Itoa(i), r); err != nil {
					b.Fatal(err)
				}
				r.Close()

				if err := fds.AddFile(strconv.Itoa(i+1), w); err != nil {
					b.Fatal(err)
				}
				w.Close()
			}

			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				u, err := newUpgrader(stdEnv, Options{})
				if err != nil {
					b.Fatal("Can't create Upgrader:", err)
				}
				if err := u.Ready(); err != nil {
					b.Fatal("Can't call Ready:", err)
				}

				u.Fds = fds
				// Like TestUpgraderOnOS, tolerate errNotReady immediately
				// after Ready instead of failing the benchmark spuriously.
				for {
					err = u.Upgrade()
					if err == nil {
						break
					}
					if err != errNotReady {
						b.Fatal(err)
					}
				}
			}
			b.StopTimer()

			for _, f := range fds.used {
				f.Close()
			}
		})
	}
}