#!/usr/bin/env python3

# Copyright (c) 2016, Antonio SJ Musumeci <trapexit@spawn.link>
#
# Permission to use, copy, modify, and/or distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

import argparse
import ctypes
import errno
import fnmatch
import io
import os
import shlex
import subprocess
import sys


# Bind lgetxattr(2) directly from glibc.  use_errno=True makes ctypes keep a
# private errno copy that ctypes.get_errno() reads back after each call.
_libc = ctypes.CDLL("libc.so.6",use_errno=True)
_lgetxattr = _libc.lgetxattr
_lgetxattr.argtypes = [ctypes.c_char_p,ctypes.c_char_p,ctypes.c_void_p,ctypes.c_size_t]
# lgetxattr(2) returns ssize_t; without this ctypes assumes a plain C int.
_lgetxattr.restype = ctypes.c_ssize_t
def lgetxattr(path,name):
    """Read the extended attribute `name` of `path` via lgetxattr(2).

    path: str or bytes.  A str is converted with os.fsencode() so that
          names obtained from os.walk()/os.listdir() (which may carry
          surrogate escapes for undecodable bytes) map back to the exact
          bytes on disk.
    name: str or bytes attribute name.

    Returns the attribute value decoded with os.fsdecode() (the inverse of
    the conversion applied to `path`, so returned paths can be handed back
    to the os module), or None when the call fails with ENODATA.

    Raises IOError(errno, strerror, path) for any other failure.
    """
    if isinstance(path,str):
        path = os.fsencode(path)
    if isinstance(name,str):
        name = name.encode(errors='backslashreplace')
    length = 64
    while True:
        buf = ctypes.create_string_buffer(length)
        res = _lgetxattr(path,name,buf,ctypes.c_size_t(length))
        if res >= 0:
            return os.fsdecode(buf.raw[0:res])
        err = ctypes.get_errno()
        if err == errno.ERANGE:
            # Buffer too small for the value: retry with twice the space.
            length *= 2
        elif err == errno.ENODATA:
            return None
        else:
            raise IOError(err,os.strerror(err),path)


def ismergerfs(path):
    """Return True when `path` exposes the `user.mergerfs.basepath` xattr.

    lgetxattr() reports a missing attribute by returning None rather than
    raising, so both outcomes have to be treated as "not mergerfs": a None
    result and an IOError from the lookup each yield False.
    """
    try:
        return lgetxattr(path,'user.mergerfs.basepath') is not None
    except IOError:
        return False


def mergerfs_control_file(basedir):
    """Find the nearest `.mergerfs` entry at or above `basedir`.

    Walks from `basedir` towards the filesystem root and returns the path
    of the first `<dir>/.mergerfs` that exists.  The directory '/' itself
    is never probed.  Returns None when nothing is found.

    The walk also stops once os.path.dirname() no longer changes the path,
    so relative or otherwise unusual inputs ('' or '//') terminate instead
    of recursing forever.
    """
    while basedir != '/':
        ctrlfile = os.path.join(basedir,'.mergerfs')
        if os.path.exists(ctrlfile):
            return ctrlfile
        parent = os.path.dirname(basedir)
        if parent == basedir:
            # Reached a fixpoint that is not '/': nothing left to climb.
            return None
        basedir = parent
    return None


def mergerfs_branches(ctrlfile):
    """Return the `user.mergerfs.srcmounts` xattr of `ctrlfile` as a list.

    The attribute value is split on ':'; each element is one entry of the
    resulting list.
    """
    return lgetxattr(ctrlfile,'user.mergerfs.srcmounts').split(':')


def match(filename,matches):
    """Return True if `filename` matches any fnmatch pattern in `matches`.

    An empty `matches` sequence never matches.
    """
    return any(fnmatch.fnmatch(filename,pattern) for pattern in matches)


def execute_cmd(args):
    """Run the command given as the argument list `args`.

    Returns whatever subprocess.call() returns for it (the exit status).
    """
    status = subprocess.call(args)
    return status


def print_args(args):
    """Print `args` as one shell-quoted, space-separated command line."""
    print(' '.join(map(shlex.quote,args)))


def build_copy_file(src,tgt,rel):
    """Build the rsync argument list that copies `rel` from `src` to `tgt`.

    The source argument is `src` joined with './' and `rel`; the target
    argument is `tgt` with a trailing '/'.
    NOTE(review): the './' component is presumably the anchor rsync's
    --relative option uses to decide how much of the path to recreate
    under the target -- confirm against rsync(1).
    """
    source = os.path.join(src,'./',rel)
    command = ['rsync',
               '-avHAXWE',
               '--numeric-ids',
               '--progress',
               '--relative']
    command.append(source)
    command.append(tgt + '/')
    return command


def build_branches_freespace(branches):
    """Map every path in `branches` to `f_bavail * f_frsize` from statvfs.

    Returns a dict keyed by branch path.  Errors from os.statvfs()
    propagate to the caller.
    """
    stats = ((branch,os.statvfs(branch)) for branch in branches)
    return {branch: st.f_bavail * st.f_frsize for branch,st in stats}


def print_help():
    """Print the hand-written usage text for mergerfs.dup to stdout."""
    usage_text = '''
usage: mergerfs.dup [<options>] <dir>

Duplicate files & directories across multiple drives in a pool.
Will print out commands for inspection and out of band use.

positional arguments:
  dir                    starting directory

optional arguments:
  -c, --count=           Number of copies to create. (default: 2)
  -d, --dup=             Which file (if more than one exists) to choose to
                         duplicate. Each one falls back to `mergerfs` if
                         all files have the same value. (default: newest)
                         * newest   : file with largest mtime
                         * oldest   : file with smallest mtime
                         * smallest : file with smallest size
                         * largest  : file with largest size
                         * mergerfs : file chosen by mergerfs' getattr
  -p, --prune            Remove files above `count`. Without this enabled
                         it will update all existing files.
  -e, --execute          Execute `rsync` and `rm` commands. Not just
                         print them.
  -I, --include=         fnmatch compatible filter to include files.
                         Can be used multiple times.
  -E, --exclude=         fnmatch compatible filter to exclude files.
                         Can be used multiple times.
'''
    print(usage_text)


def buildargparser():
    """Create the argparse parser for the command line.

    argparse's built-in help is disabled (add_help=False); -h/--help is
    registered as an ordinary store_true flag instead.  Arguments are
    registered in the order they are listed in the table below.
    """
    specs = (
        (('dir',),
         dict(type=str,nargs='?',default=None)),
        (('-c','--count'),
         dict(dest='count',type=int,default=2)),
        (('-p','--prune'),
         dict(dest='prune',action='store_true')),
        (('-d','--dup'),
         dict(choices=['newest','oldest',
                       'smallest','largest',
                       'mergerfs'],
              default='newest')),
        (('-e','--execute'),
         dict(dest='execute',action='store_true')),
        (('-I','--include'),
         dict(dest='include',type=str,action='append',default=[])),
        (('-E','--exclude'),
         dict(dest='exclude',type=str,action='append',default=[])),
        (('-h','--help'),
         dict(action='store_true')),
    )

    parser = argparse.ArgumentParser(add_help=False)
    for flags,options in specs:
        parser.add_argument(*flags,**options)
    return parser


def xattr_basepath(fullpath):
    """Return lgetxattr()'s result for `user.mergerfs.basepath` on fullpath."""
    attribute = 'user.mergerfs.basepath'
    return lgetxattr(fullpath,attribute)


def xattr_allpaths(fullpath):
    """Return lgetxattr()'s result for `user.mergerfs.allpaths` on fullpath."""
    attribute = 'user.mergerfs.allpaths'
    return lgetxattr(fullpath,attribute)


def xattr_relpath(fullpath):
    """Return lgetxattr()'s result for `user.mergerfs.relpath` on fullpath."""
    attribute = 'user.mergerfs.relpath'
    return lgetxattr(fullpath,attribute)


def exists(base,rel,name):
    """Return True if `base/rel/name` exists, without following symlinks.

    Uses os.path.lexists(), so a dangling symlink counts as existing.
    """
    return os.path.lexists(os.path.join(base,rel,name))


def mergerfs_all_basepaths(fullpath,relpath):
    """List the base paths that hold a copy of `fullpath`.

    fullpath: path of the entry to look up.
    relpath:  the entry's relative path; its length is the number of
              trailing characters removed from every candidate path.

    The NUL-separated `user.mergerfs.allpaths` xattr of `fullpath` is used
    when it is present and non-empty.  Otherwise the parent directory's
    attribute is read and each of its paths is kept only if joining it with
    the entry's basename names something that exists (lexists).

    Returns a list of candidate paths with the trailing `relpath`-sized
    suffix and any trailing '/' removed.
    """
    attr = xattr_allpaths(fullpath)
    if attr:
        paths = attr.split('\0')
    else:
        dirname  = os.path.dirname(fullpath)
        basename = os.path.basename(fullpath)
        candidates = [os.path.join(path,basename)
                      for path in xattr_allpaths(dirname).split('\0')]
        paths = [path for path in candidates if os.path.lexists(path)]

    # Explicit length arithmetic instead of x[:-len(relpath)]: with an empty
    # relpath the latter is x[:-0], which is always the empty string.
    striplen = len(relpath)
    return [path[:max(len(path) - striplen,0)].rstrip('/') for path in paths]


def mergerfs_basepath(fullpath):
    """Return the base path for `fullpath`.

    The `user.mergerfs.basepath` xattr of `fullpath` is returned when it is
    present and non-empty.  Otherwise the parent directory's NUL-separated
    `user.mergerfs.allpaths` is scanned in order; for the first path under
    which the entry's basename exists (lexists), that path is returned with
    the parent's `user.mergerfs.relpath`-sized suffix and any trailing '/'
    removed.

    Returns None when no candidate is found (previously the intermediate
    list of candidate directories leaked out as the return value).
    """
    attr = xattr_basepath(fullpath)
    if attr:
        return attr

    dirname  = os.path.dirname(fullpath)
    basename = os.path.basename(fullpath)
    for path in xattr_allpaths(dirname).split('\0'):
        if os.path.lexists(os.path.join(path,basename)):
            relpath = xattr_relpath(dirname)
            # Length arithmetic rather than path[:-len(relpath)], which
            # would yield '' for an empty relpath.
            return path[:max(len(path) - len(relpath),0)].rstrip('/')
    return None


def mergerfs_relpath(fullpath):
    """Return the relative path of `fullpath` without leading '/'.

    Uses the entry's own `user.mergerfs.relpath` xattr when it is present
    and non-empty; otherwise joins the parent directory's attribute with
    the entry's basename.
    """
    relpath = xattr_relpath(fullpath)
    if relpath:
        return relpath.lstrip('/')

    parent,leaf = os.path.split(fullpath)
    parent_relpath = xattr_relpath(parent)
    return os.path.join(parent_relpath,leaf).lstrip('/')


def newest_dupfun(default_basepath,relpath,basepaths):
    """Pick the base path whose copy of `relpath` is newest.

    Each `basepath/relpath` is lstat'ed.  If the mtimes differ, the base
    path with the largest mtime is returned; failing that, if the ctimes
    differ, the one with the largest ctime.  When both are identical
    everywhere, `default_basepath` is returned.
    """
    stats = {base: os.lstat(os.path.join(base,relpath)) for base in basepaths}
    reference = stats[basepaths[0]]

    if any(st.st_mtime != reference.st_mtime for st in stats.values()):
        return max(stats,key=lambda base: stats[base].st_mtime)

    if any(st.st_ctime != reference.st_ctime for st in stats.values()):
        return max(stats,key=lambda base: stats[base].st_ctime)

    return default_basepath


def oldest_dupfun(default_basepath,relpath,basepaths):
    """Pick the base path whose copy of `relpath` is oldest.

    Each `basepath/relpath` is lstat'ed.  If the mtimes differ, the base
    path with the smallest mtime is returned; failing that, if the ctimes
    differ, the one with the smallest ctime.  When both are identical
    everywhere, `default_basepath` is returned.
    """
    stats = {base: os.lstat(os.path.join(base,relpath)) for base in basepaths}
    reference = stats[basepaths[0]]

    if any(st.st_mtime != reference.st_mtime for st in stats.values()):
        return min(stats,key=lambda base: stats[base].st_mtime)

    if any(st.st_ctime != reference.st_ctime for st in stats.values()):
        return min(stats,key=lambda base: stats[base].st_ctime)

    return default_basepath


def largest_dupfun(default_basepath,relpath,basepaths):
    """Pick the base path whose copy of `relpath` is largest.

    Each `basepath/relpath` is lstat'ed.  If the sizes differ, the base
    path with the largest st_size is returned; otherwise
    `default_basepath`.
    """
    stats = {base: os.lstat(os.path.join(base,relpath)) for base in basepaths}
    reference = stats[basepaths[0]]

    if any(st.st_size != reference.st_size for st in stats.values()):
        return max(stats,key=lambda base: stats[base].st_size)

    return default_basepath


def smallest_dupfun(default_basepath,relpath,basepaths):
    """Pick the base path whose copy of `relpath` is smallest.

    Each `basepath/relpath` is lstat'ed.  If the sizes differ, the base
    path with the smallest st_size is returned; otherwise
    `default_basepath`.
    """
    stats = {base: os.lstat(os.path.join(base,relpath)) for base in basepaths}
    reference = stats[basepaths[0]]

    if any(st.st_size != reference.st_size for st in stats.values()):
        return min(stats,key=lambda base: stats[base].st_size)

    return default_basepath


def mergerfs_dupfun(default_basepath,relpath,basepaths):
    """Always choose `default_basepath`.

    `relpath` and `basepaths` are accepted only so the signature matches
    the other *_dupfun selectors; they are not inspected.
    """
    chosen = default_basepath
    return chosen


def getdupfun(name):
    """Return the selector function registered under `name`.

    Valid names are 'newest', 'oldest', 'smallest', 'largest' and
    'mergerfs'; any other name raises KeyError.
    """
    strategies = dict(newest=newest_dupfun,
                      oldest=oldest_dupfun,
                      smallest=smallest_dupfun,
                      largest=largest_dupfun,
                      mergerfs=mergerfs_dupfun)
    return strategies[name]


def main():
    """Command-line entry point.

    Parses the arguments, verifies that the target directory is on a
    mergerfs mount, then walks it.  For every file passing the
    include/exclude filters it prints (and, with --execute, runs):

      * '# overwrite' rsync commands refreshing existing copies from the
        copy selected by the --dup strategy,
      * '# copy' rsync commands creating additional copies on the branches
        with the most free space until `count` copies exist,
      * with --prune, '# remove' rm commands for existing copies that were
        not kept.

    Exits with status 1 when the directory is not on mergerfs or its
    control file cannot be located, otherwise with status 0.
    """
    # Re-wrap the standard streams as UTF-8 with backslashreplace so that
    # printing names containing unencodable characters cannot raise.
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer,
                                  encoding='utf8',
                                  errors='backslashreplace',
                                  line_buffering=True)
    sys.stderr = io.TextIOWrapper(sys.stderr.buffer,
                                  encoding='utf8',
                                  errors='backslashreplace',
                                  line_buffering=True)

    parser = buildargparser()
    args = parser.parse_args()

    if args.help or not args.dir:
        print_help()
        sys.exit(0)

    args.dir = os.path.realpath(args.dir)

    if not ismergerfs(args.dir):
        print("%s is not a mergerfs mount" % args.dir)
        sys.exit(1)

    prune     = args.prune
    execute   = args.execute
    includes  = ['*'] if not args.include else args.include
    excludes  = args.exclude
    dupfun    = getdupfun(args.dup)
    ctrlfile  = mergerfs_control_file(args.dir)
    if ctrlfile is None:
        # Without the control file the branch list cannot be read; fail
        # with a message instead of a traceback from the xattr lookup.
        print("%s: unable to locate the .mergerfs control file" % args.dir)
        sys.exit(1)
    branches  = mergerfs_branches(ctrlfile)
    # From here on `branches` maps branch path -> estimated free bytes.
    branches  = build_branches_freespace(branches)
    # Cannot keep more copies than there are branches.
    count     = min(args.count,len(branches))

    try:
        for (dirpath,dirnames,filenames) in os.walk(args.dir):
            for filename in filenames:
                if match(filename,excludes):
                    continue
                if not match(filename,includes):
                    continue

                fullpath = os.path.join(dirpath,filename)
                basepath = mergerfs_basepath(fullpath)
                relpath  = mergerfs_relpath(fullpath)
                existing = mergerfs_all_basepaths(fullpath,relpath)

                # Choose the copy that acts as the rsync source; the
                # remaining entries of `existing` are the other copies.
                srcpath  = dupfun(basepath,relpath,existing)
                srcfile  = os.path.join(srcpath,relpath)
                srcfile_size = os.lstat(srcfile).st_size
                existing.remove(srcpath)

                # `i` counts the copies accounted for so far; the source
                # itself is the first one.
                i = 1
                copies = []
                for tgtpath in existing:
                    # When pruning, stop refreshing once `count` copies are
                    # covered; the rest become removal candidates below.
                    if prune and i >= count:
                        break
                    copies.append(tgtpath)
                    cmd = build_copy_file(srcpath,tgtpath,relpath)
                    print('# overwrite')
                    print_args(cmd)
                    if execute:
                        execute_cmd(cmd)
                    i += 1

                # Create the copies still missing, one per iteration, on
                # the branch with the most free space that lacks the file.
                for _ in range(i,count):
                    for branch in sorted(branches,key=branches.get,reverse=True):
                        tgtfile = os.path.join(branch,relpath)
                        if branch in copies or os.path.exists(tgtfile):
                            continue
                        copies.append(branch)
                        # Keep the free-space estimate current for later files.
                        branches[branch] -= srcfile_size
                        cmd = build_copy_file(srcpath,branch,relpath)
                        print('# copy')
                        print_args(cmd)
                        if execute:
                            execute_cmd(cmd)
                        break

                if prune:
                    leftovers = set(existing) - set(copies)
                    for branch in leftovers:
                        branches[branch] += srcfile_size
                        tgtfile = os.path.join(branch,relpath)
                        print('# remove')
                        cmd = ['rm','-vf',tgtfile]
                        print_args(cmd)
                        if execute:
                            execute_cmd(cmd)


    except KeyboardInterrupt:
        print("exiting: CTRL-C pressed")
    except BrokenPipeError:
        # Output consumer went away (e.g. piped into `head`); exit quietly.
        pass

    sys.exit(0)


# Run the tool only when executed as a script, not when imported.
if __name__ == '__main__':
    main()