#!/usr/bin/env python3

# Copyright (c) 2016, Antonio SJ Musumeci <trapexit@spawn.link>

# Permission to use, copy, modify, and/or distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.

# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

import argparse
import ctypes
import errno
import io
import os
import sys


# Bind lgetxattr(2) from the C library.  use_errno=True makes ctypes keep a
# private copy of errno that ctypes.get_errno() can read after each call.
_libc = ctypes.CDLL("libc.so.6",use_errno=True)
_lgetxattr = _libc.lgetxattr
_lgetxattr.argtypes = [ctypes.c_char_p,ctypes.c_char_p,ctypes.c_void_p,ctypes.c_size_t]
# lgetxattr(2) returns ssize_t; without this ctypes assumes a C int.
_lgetxattr.restype = ctypes.c_ssize_t


def lgetxattr(path,name):
    """Return the value of extended attribute `name` on `path`.

    `path` and `name` may be `str` or `bytes`.  A `str` path is converted
    with os.fsencode() so that surrogate-escaped names (as produced by
    os.walk() for file names that are not valid in the filesystem
    encoding) map back to the original bytes on disk.

    Returns the raw attribute value as `bytes`, or None when the attribute
    does not exist (ENODATA).  Any other failure raises IOError carrying
    the errno, its message and the path.
    """
    if isinstance(path,str):
        path = os.fsencode(path)
    if isinstance(name,str):
        name = name.encode(errors='backslashreplace')
    length = 64
    while True:
        buf = ctypes.create_string_buffer(length)
        res = _lgetxattr(path,name,buf,ctypes.c_size_t(length))
        if res >= 0:
            return buf.raw[0:res]
        err = ctypes.get_errno()
        if err == errno.ERANGE:
            # Buffer too small for the value: retry with twice the room.
            length *= 2
        elif err == errno.ENODATA:
            return None
        else:
            raise IOError(err,os.strerror(err),path)


def ismergerfs(path):
    """Return True when `path` exposes the "user.mergerfs.fullpath" xattr.

    lgetxattr() returns None when the attribute is simply absent (ENODATA),
    which is what an ordinary xattr-capable filesystem reports, so a
    missing attribute is treated the same as a failed lookup: not mergerfs.
    """
    try:
        return lgetxattr(path,"user.mergerfs.fullpath") is not None
    except IOError:
        return False


def setstat(stat,paths):
    """Apply the owner, group and mode from `stat` to every path in `paths`.

    `stat` is any object with st_uid, st_gid and st_mode (e.g. an
    os.stat_result).  Each path is handled best-effort: an OSError is
    printed and processing moves on.

    Ownership is applied before the mode because chown(2) may clear the
    set-user-ID / set-group-ID bits; doing chmod last keeps them.  The two
    steps are attempted independently, so a caller who may not change
    ownership still gets the mode applied.
    """
    for path in paths:
        # Paths read from xattrs are bytes; decode for display the same way
        # print_stats does, so output does not show a b'...' repr.
        if isinstance(path,bytes):
            shown = path.decode(errors='backslashreplace')
        else:
            shown = path

        applied = True
        try:
            os.chown(path,stat.st_uid,stat.st_gid)
        except OSError as e:
            applied = False
            print("%s" % e)
        try:
            os.chmod(path,stat.st_mode)
        except OSError as e:
            applied = False
            print("%s" % e)

        if applied:
            print("set %s > uid: %d gid: %d mode: %o" %
                  (shown,stat.st_uid,stat.st_gid,stat.st_mode))


def stats_different(stats):
    """Return True if any entry differs from the first in mode, uid or gid.

    `stats` must be a non-empty sequence of stat-like objects; indexing an
    empty sequence raises IndexError.
    """
    reference = stats[0]
    expected = (reference.st_mode,reference.st_uid,reference.st_gid)
    return any((entry.st_mode,entry.st_uid,entry.st_gid) != expected
               for entry in stats)

def size_equal(stats):
    """Return True when every entry has the same st_size as the first.

    `stats` must be non-empty; an empty sequence raises IndexError.
    """
    first_size = stats[0].st_size
    return all(entry.st_size == first_size for entry in stats)

def print_stats(Files,Stats):
    """Print an indexed listing of `Files` with the matching stat details.

    `Files` holds bytes paths (decoded for display with backslashreplace)
    and `Stats` the stat-like object at the same position.  For each entry
    two lines are written: the index and path, then uid, gid, octal mode,
    size and mtime.
    """
    template = ("   - uid: {0:5}; gid: {1:5}; mode: {2:6o}; "
                "size: {3:10}; mtime: {4}")
    for index,name in enumerate(Files):
        print("  %i: %s" % (index,name.decode(errors='backslashreplace')))
        entry = Stats[index]
        line = template.format(entry.st_uid,
                               entry.st_gid,
                               entry.st_mode,
                               entry.st_size,
                               entry.st_mtime)
        print(line)


def noop_fix(paths,stats):
    """Fix policy that deliberately changes nothing (report-only mode)."""
    return None


def manual_fix(paths,stats):
    """Ask which entry is correct and apply its stat to every path.

    The prompt repeats only while the answer is an integer outside the
    valid index range.  Any exception (a non-integer answer, end of input,
    a failure while applying) is printed and ends the prompt without
    retrying.
    """
    while True:
        try:
            choice = int(input('Which is correct?: '))
            if not (0 <= choice < len(paths)):
                print("Input error: enter a value [0,%d]" % (len(paths)-1))
                continue
            setstat(stats[choice],paths)
        except Exception as err:
            print("%s" % err)
        # Reached after a successful apply or a reported error.
        break


def newest_fix(paths,stats):
    """Apply the stat of the most recently modified entry to every path.

    `stats` is left untouched: `paths` and `stats` are parallel lists, and
    sorting one in place would break that pairing for the caller.  Errors
    (including an empty `stats`) are printed rather than raised.
    """
    try:
        # Scanning in reverse makes max() keep the last of several entries
        # sharing the greatest mtime, which is the entry a stable ascending
        # sort would have placed at the end.
        newest = max(reversed(stats),key=lambda entry: entry.st_mtime)
        setstat(newest,paths)
    except Exception as e:
        print("%s" % e)


def nonroot_fix(paths,stats):
    """Apply the first stat whose owner is not uid 0 to every path.

    When every entry is owned by uid 0 the decision is delegated to
    newest_fix.  Any exception raised along the way is printed.
    """
    try:
        candidate = next((entry for entry in stats if entry.st_uid != 0),None)
        if candidate is None:
            return newest_fix(paths,stats)
        setstat(candidate,paths)
    except Exception as err:
        print("%s" % err)


def getfixfun(name):
    """Map a fix policy name to its function.

    Unknown names (including None, when no policy was requested) map to
    noop_fix.
    """
    policies = {
        'manual': manual_fix,
        'newest': newest_fix,
        'nonroot': nonroot_fix,
    }
    return policies.get(name,noop_fix)


def check_consistancy(fullpath,verbose,size,fix):
    """Report `fullpath` if its copies disagree on mode, uid or gid.

    The "user.mergerfs.allpaths" xattr is read and split on NUL bytes;
    presumably each piece is the path of one underlying copy.  Nothing
    happens when there are fewer than two paths, when `size` is set and
    the copies differ in size, or when mode/uid/gid all agree.  Otherwise
    the path is printed, details are shown if `verbose`, and `fix` is
    called with the path list and the matching stat list.
    """
    raw = lgetxattr(fullpath,"user.mergerfs.allpaths")
    branches = raw.split(b'\0') if raw else []
    if len(branches) <= 1:
        return

    stats = list(map(os.stat,branches))
    sizes_ok = (not size) or size_equal(stats)
    if not (sizes_ok and stats_different(stats)):
        return

    print("%s" % fullpath)
    if verbose:
        print_stats(branches,stats)
    fix(branches,stats)


def buildargparser():
    """Create the command-line parser for the audit tool.

    Defines the positional starting directory, the boolean -v/--verbose
    and -s/--size switches, and the -f/--fix policy choice.
    """
    parser = argparse.ArgumentParser(
        description='audit a mergerfs mount for inconsistencies')
    parser.add_argument('dir',type=str,help='starting directory')

    switches = (
        (('-v','--verbose'),'print details of audit item'),
        (('-s','--size'),'only consider if the size is the same'),
    )
    for flags,text in switches:
        parser.add_argument(*flags,action='store_true',help=text)

    parser.add_argument('-f','--fix',
                        choices=['manual','newest','nonroot'],
                        help='fix policy')
    return parser


def main():
    """Entry point: parse arguments and audit the tree under the given dir.

    Exits with status 1 when the directory fails the mergerfs check and 0
    otherwise.  Ctrl-C and a broken output pipe end the walk quietly; any
    other IOError propagates.
    """
    # Re-wrap both standard streams so characters that cannot be encoded
    # are written as backslash escapes instead of raising.
    for stream_name in ('stdout','stderr'):
        raw = getattr(sys,stream_name).buffer
        setattr(sys,stream_name,io.TextIOWrapper(raw,
                                                 encoding='utf8',
                                                 errors='backslashreplace',
                                                 line_buffering=True))

    args = buildargparser().parse_args()

    # Choosing a fix policy implies verbose output.
    verbose = True if args.fix else args.verbose
    size = args.size
    fix = getfixfun(args.fix)

    root = os.path.realpath(args.dir)
    if not ismergerfs(root):
        print("%s is not a mergerfs directory" % root)
        sys.exit(1)

    try:
        # `root` is absolute, so every dirpath yielded by os.walk is too.
        for dirpath,_subdirs,filenames in os.walk(root):
            check_consistancy(dirpath,verbose,size,fix)
            for filename in filenames:
                check_consistancy(os.path.join(dirpath,filename),
                                  verbose,size,fix)
    except KeyboardInterrupt:
        pass
    except IOError as err:
        if err.errno != errno.EPIPE:
            raise

    sys.exit(0)


# Run the audit only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()