// Package archive streams a directory through tar+zstd (and back). // // Compression level is fixed at zstd level 3 — same as the existing shell // (`zstd -T0 -3`). The encoder is wrapped around an io.Writer (typically an // S3 multipart Uploader's PipeWriter), so the whole pipeline stays // streaming end-to-end with no on-disk tempfile. package archive import ( "archive/tar" "errors" "fmt" "io" "io/fs" "os" "path/filepath" "sort" "strings" "github.com/klauspost/compress/zstd" ) // EncodeDir walks root in sorted relative-path order and writes a // tar.zst stream into w. Symlinks are preserved (stored as tar typelink), // devices/sockets/fifos are skipped. func EncodeDir(w io.Writer, root string) error { zw, err := zstd.NewWriter(w, zstd.WithEncoderLevel(zstd.SpeedDefault)) if err != nil { return fmt.Errorf("zstd writer: %w", err) } defer zw.Close() tw := tar.NewWriter(zw) var paths []string err = filepath.WalkDir(root, func(p string, d fs.DirEntry, err error) error { if err != nil { return err } if p == root { return nil } paths = append(paths, p) return nil }) if err != nil { return fmt.Errorf("walk %s: %w", root, err) } sort.Strings(paths) for _, p := range paths { if err := writeOne(tw, root, p); err != nil { return err } } if err := tw.Close(); err != nil { return fmt.Errorf("close tar: %w", err) } if err := zw.Close(); err != nil { return fmt.Errorf("close zstd: %w", err) } return nil } func writeOne(tw *tar.Writer, root, p string) error { rel, err := filepath.Rel(root, p) if err != nil { return err } rel = filepath.ToSlash(rel) lst, err := os.Lstat(p) if err != nil { return fmt.Errorf("lstat %s: %w", p, err) } var link string if lst.Mode()&os.ModeSymlink != 0 { link, err = os.Readlink(p) if err != nil { return fmt.Errorf("readlink %s: %w", p, err) } } hdr, err := tar.FileInfoHeader(lst, link) if err != nil { // Skip unsupported file types (devices, sockets, fifos). return nil } hdr.Name = rel if lst.IsDir() && !strings.HasSuffix(hdr.Name, "/") { hdr.Name += "/" } if err := tw.WriteHeader(hdr); err != nil { return fmt.Errorf("tar header %s: %w", rel, err) } if !lst.Mode().IsRegular() { return nil } f, err := os.Open(p) if err != nil { return fmt.Errorf("open %s: %w", p, err) } defer f.Close() if _, err := io.Copy(tw, f); err != nil { return fmt.Errorf("copy %s: %w", rel, err) } return nil } // DecodeDir extracts a tar.zst stream from r into dest. dest must exist. // Any entry whose normalized path escapes dest is rejected. func DecodeDir(r io.Reader, dest string) error { absDest, err := filepath.Abs(dest) if err != nil { return fmt.Errorf("abs %s: %w", dest, err) } zr, err := zstd.NewReader(r) if err != nil { return fmt.Errorf("zstd reader: %w", err) } defer zr.Close() tr := tar.NewReader(zr) for { hdr, err := tr.Next() if errors.Is(err, io.EOF) { return nil } if err != nil { return fmt.Errorf("tar next: %w", err) } target, err := safeJoin(absDest, hdr.Name) if err != nil { return err } switch hdr.Typeflag { case tar.TypeDir: if err := os.MkdirAll(target, fs.FileMode(hdr.Mode)&0o7777); err != nil { return fmt.Errorf("mkdir %s: %w", target, err) } case tar.TypeReg, tar.TypeRegA: if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil { return fmt.Errorf("mkdir parent of %s: %w", target, err) } f, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fs.FileMode(hdr.Mode)&0o7777) if err != nil { return fmt.Errorf("create %s: %w", target, err) } if _, err := io.Copy(f, tr); err != nil { f.Close() return fmt.Errorf("write %s: %w", target, err) } if err := f.Close(); err != nil { return err } case tar.TypeSymlink: _ = os.Remove(target) if err := os.Symlink(hdr.Linkname, target); err != nil { return fmt.Errorf("symlink %s -> %s: %w", target, hdr.Linkname, err) } default: // Skip unknown entry types silently. } } } // safeJoin returns base/sub, rejecting any sub that is absolute or contains // ".." segments that would resolve outside base. We reject rather than // silently re-root so malicious tarballs surface as an error. func safeJoin(base, sub string) (string, error) { if filepath.IsAbs(sub) || strings.HasPrefix(sub, "/") { return "", fmt.Errorf("absolute tar entry: %q", sub) } clean := filepath.Clean(sub) if clean == ".." || strings.HasPrefix(clean, ".."+string(filepath.Separator)) { return "", fmt.Errorf("tar entry escapes destination: %q", sub) } return filepath.Join(base, clean), nil }