// Package archive streams a directory through tar+zstd (and back).
//
// Compression uses zstd.SpeedDefault, which klauspost/compress documents as
// roughly equivalent to zstd level 3 — matching the existing shell pipeline
// (`zstd -T0 -3`). The encoder is wrapped around an io.Writer (typically an
// S3 multipart Uploader's PipeWriter), so the whole pipeline stays
// streaming end-to-end with no on-disk tempfile.
package archive
import (
"archive/tar"
"errors"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
"sort"
"strings"
"github.com/klauspost/compress/zstd"
)
// EncodeDir writes a tar.zst archive of everything beneath root to w; root
// itself is not stored as an entry. Paths found by filepath.WalkDir are
// sorted with sort.Strings before being archived, so output order is
// deterministic. Symlinks are stored as tar symlink entries rather than
// followed; the per-entry type handling is delegated to writeOne.
func EncodeDir(w io.Writer, root string) error {
	enc, err := zstd.NewWriter(w, zstd.WithEncoderLevel(zstd.SpeedDefault))
	if err != nil {
		return fmt.Errorf("zstd writer: %w", err)
	}
	// Runs on every return. The success path also closes explicitly below so
	// that the final flush error is reported to the caller.
	defer enc.Close()

	archive := tar.NewWriter(enc)

	var entries []string
	collect := func(entryPath string, _ fs.DirEntry, walkErr error) error {
		switch {
		case walkErr != nil:
			return walkErr
		case entryPath != root:
			entries = append(entries, entryPath)
		}
		return nil
	}
	if walkErr := filepath.WalkDir(root, collect); walkErr != nil {
		return fmt.Errorf("walk %s: %w", root, walkErr)
	}
	sort.Strings(entries)

	for _, entry := range entries {
		if werr := writeOne(archive, root, entry); werr != nil {
			return werr
		}
	}

	if cerr := archive.Close(); cerr != nil {
		return fmt.Errorf("close tar: %w", cerr)
	}
	if cerr := enc.Close(); cerr != nil {
		return fmt.Errorf("close zstd: %w", cerr)
	}
	return nil
}
// writeOne appends the filesystem entry at p (a path under root) to tw. The
// entry is named by its slash-separated path relative to root:
//   - directories are stored with a trailing "/";
//   - symlinks are stored with their link target (not followed);
//   - regular files are stored with their contents.
//
// Any other file type (device, FIFO, socket, ...) is skipped without error.
func writeOne(tw *tar.Writer, root, p string) error {
	rel, err := filepath.Rel(root, p)
	if err != nil {
		return err
	}
	rel = filepath.ToSlash(rel)
	lst, err := os.Lstat(p)
	if err != nil {
		return fmt.Errorf("lstat %s: %w", p, err)
	}
	mode := lst.Mode()
	var link string
	switch {
	case mode.IsRegular(), mode.IsDir():
		// Archived below.
	case mode&os.ModeSymlink != 0:
		link, err = os.Readlink(p)
		if err != nil {
			return fmt.Errorf("readlink %s: %w", p, err)
		}
	default:
		// tar.FileInfoHeader encodes devices and FIFOs (it only fails for
		// sockets and unknown modes), so special files must be filtered
		// out here explicitly rather than by relying on its error.
		return nil
	}
	hdr, err := tar.FileInfoHeader(lst, link)
	if err != nil {
		// Not expected for regular files, directories or symlinks; surface
		// it instead of silently dropping the entry.
		return fmt.Errorf("tar header info %s: %w", rel, err)
	}
	hdr.Name = rel
	if mode.IsDir() && !strings.HasSuffix(hdr.Name, "/") {
		hdr.Name += "/"
	}
	if err := tw.WriteHeader(hdr); err != nil {
		return fmt.Errorf("tar header %s: %w", rel, err)
	}
	if !mode.IsRegular() {
		return nil
	}
	f, err := os.Open(p)
	if err != nil {
		return fmt.Errorf("open %s: %w", p, err)
	}
	defer f.Close()
	if _, err := io.Copy(tw, f); err != nil {
		return fmt.Errorf("copy %s: %w", rel, err)
	}
	return nil
}
// DecodeDir extracts a tar.zst stream from r into dest. dest must exist.
// Any entry whose normalized path escapes dest is rejected. So is any entry
// whose parent path under dest already exists as a symlink. Writing through
// such a link (e.g. an archived "x -> /etc" followed by "x/passwd") would
// otherwise land outside dest.
//
// Directories, regular files and symlinks are restored; all other entry
// types are skipped.
func DecodeDir(r io.Reader, dest string) error {
	absDest, err := filepath.Abs(dest)
	if err != nil {
		return fmt.Errorf("abs %s: %w", dest, err)
	}
	zr, err := zstd.NewReader(r)
	if err != nil {
		return fmt.Errorf("zstd reader: %w", err)
	}
	defer zr.Close()
	tr := tar.NewReader(zr)
	for {
		hdr, err := tr.Next()
		if errors.Is(err, io.EOF) {
			return nil
		}
		if err != nil {
			return fmt.Errorf("tar next: %w", err)
		}
		if err := extractEntry(tr, hdr, absDest); err != nil {
			return err
		}
	}
}

// extractEntry materializes one tar entry under absDest. For regular files,
// the contents are read from tr.
func extractEntry(tr *tar.Reader, hdr *tar.Header, absDest string) error {
	target, err := safeJoin(absDest, hdr.Name)
	if err != nil {
		return err
	}
	// safeJoin is purely lexical. It cannot see symlinks created by earlier
	// entries (or already present in dest) that would redirect the write.
	if err := rejectSymlinkParents(absDest, target); err != nil {
		return err
	}
	mode := fs.FileMode(hdr.Mode) & 0o7777
	switch hdr.Typeflag {
	case tar.TypeDir:
		if err := os.MkdirAll(target, mode); err != nil {
			return fmt.Errorf("mkdir %s: %w", target, err)
		}
	case tar.TypeReg, tar.TypeRegA:
		return extractFile(tr, target, mode)
	case tar.TypeSymlink:
		// Archives without explicit directory entries still need the parent.
		if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
			return fmt.Errorf("mkdir parent of %s: %w", target, err)
		}
		// Best effort: clear whatever is at target. If removal fails,
		// os.Symlink reports the conflict below.
		_ = os.Remove(target)
		if err := os.Symlink(hdr.Linkname, target); err != nil {
			return fmt.Errorf("symlink %s -> %s: %w", target, hdr.Linkname, err)
		}
	default:
		// Hard links, devices, FIFOs and unknown types are skipped silently.
	}
	return nil
}

// extractFile writes the contents of r to target. It creates missing parent
// directories and passes mode as the permission for a newly created file.
func extractFile(r io.Reader, target string, mode fs.FileMode) error {
	if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
		return fmt.Errorf("mkdir parent of %s: %w", target, err)
	}
	// Opening with O_TRUNC follows a symlink at target and would clobber the
	// link's destination, which may lie outside dest. Replace the link with
	// a regular file instead.
	if fi, err := os.Lstat(target); err == nil && fi.Mode()&os.ModeSymlink != 0 {
		if err := os.Remove(target); err != nil {
			return fmt.Errorf("remove symlink %s: %w", target, err)
		}
	}
	f, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, mode)
	if err != nil {
		return fmt.Errorf("create %s: %w", target, err)
	}
	if _, err := io.Copy(f, r); err != nil {
		_ = f.Close() // the copy error is the one worth reporting
		return fmt.Errorf("write %s: %w", target, err)
	}
	if err := f.Close(); err != nil {
		return fmt.Errorf("close %s: %w", target, err)
	}
	return nil
}

// rejectSymlinkParents returns an error if any existing directory component
// between base (exclusive) and target (exclusive) is a symlink. It stops at
// the first component that does not exist yet. Everything from there down is
// created by MkdirAll as real directories.
func rejectSymlinkParents(base, target string) error {
	rel, err := filepath.Rel(base, target)
	if err != nil {
		return fmt.Errorf("rel %s: %w", target, err)
	}
	dir := filepath.Dir(rel)
	if dir == "." {
		return nil
	}
	cur := base
	for _, part := range strings.Split(dir, string(filepath.Separator)) {
		cur = filepath.Join(cur, part)
		fi, err := os.Lstat(cur)
		if errors.Is(err, fs.ErrNotExist) {
			return nil
		}
		if err != nil {
			return fmt.Errorf("lstat %s: %w", cur, err)
		}
		if fi.Mode()&os.ModeSymlink != 0 {
			return fmt.Errorf("tar entry %s: parent %s is a symlink", target, cur)
		}
	}
	return nil
}
// safeJoin resolves the archive entry name sub beneath base. Names that are
// absolute are rejected, and so are names that still begin with a ".."
// segment after lexical cleaning. They are not quietly re-rooted, so a
// hostile archive fails loudly rather than extracting somewhere unexpected.
// The check is purely lexical; it does not inspect the filesystem.
func safeJoin(base, sub string) (string, error) {
	if strings.HasPrefix(sub, "/") || filepath.IsAbs(sub) {
		return "", fmt.Errorf("absolute tar entry: %q", sub)
	}
	rel := filepath.Clean(sub)
	upPrefix := ".." + string(filepath.Separator)
	escapes := rel == ".." || strings.HasPrefix(rel, upPrefix)
	if escapes {
		return "", fmt.Errorf("tar entry escapes destination: %q", sub)
	}
	return filepath.Join(base, rel), nil
}