import React, {useEffect, useContext, useState} from 'react'
import {Alert, Anchor, Divider, message, notification, Space, Spin, Table, Typography, Collapse, Upload} from 'antd';
import bionlp24Api from "services/misc/bionlp24";
import LeaderboardDragger from "./LeaderboardDragger"
import SyntaxHighlighter from "react-syntax-highlighter";
import GlobalContext from "contexts/GlobalContext";
import Retrieve from "./Retrieve";

const {Panel} = Collapse;
const {Title, Paragraph, Link, Text} = Typography;
const {Link: LinkAnchor} = Anchor;

// Python snippet displayed on the page to show how each submitted line is
// normalised (lower-cased, then re-joined from wordpunct tokens).
const code_process = [
    'from nltk.tokenize import wordpunct_tokenize',
    'def process_impression(impression):',
    '    impression = impression.lower()',
    "    return ' '.join(wordpunct_tokenize(impression))",
].join('\n');


/**
 * Dataset identifier -> number of samples in that dataset.
 *
 * The count is the number of lines an uploaded hypothesis file must contain
 * for the given dataset, and is also displayed next to each leaderboard title.
 * Frozen because it is a shared module-level constant that must never be
 * mutated at runtime.
 */
const DATASETS = Object.freeze({
    "findings-public-test-set": 2692,
    "impression-public-test-set": 2967,
    "findings-hidden-test-set": 1063,
    "impression-hidden-test-set": 1428,
});


/**
 * Shows a persistent antd notification (duration 0) of the given kind.
 *
 * @param {string} type - name of the `notification` method to call
 *   (this file uses "error" and "success").
 * @param {*} description - body displayed under the "Notification: <type>" title.
 */
const openNotificationWithIcon = (type, description) => {
    const options = {
        message: `Notification: ${type}`,
        description,
        duration: 0,
    };
    notification[type](options);
};


// Column definitions for the "Splits" table: each column reads the row field
// of the same name as its key.
const columnsSplits = [
    ['Split', 'split'],
    ['Findings', 'findings'],
    ['Impression', 'impression'],
].map(([title, field]) => ({title, dataIndex: field, key: field}));

// Baseline scores, one row per mode. Values are stored as 0-1 fractions.
// Tuple layout: [mode, bleu, rougel, bertscore, f1chexbert, f1radgraph].
const dataSource_baseline = [
    ['Impression', 0.05015268643813864, 0.2878319023957088, 0.5194311738014221, 0.5329143538694436, 0.2697430927580646],
    ['Findings', 0.05393674731632274, 0.25791875010866055, 0.5027592182159424, 0.49967264632709185, 0.23549075818054982],
].map(([mode, bleu, rougel, bertscore, f1chexbert, f1radgraph], index) => ({
    key: String(index + 1),
    mode,
    bleu,
    rougel,
    bertscore,
    f1chexbert,
    f1radgraph,
}));


/**
 * Builds a table-cell renderer for a score stored as a 0-1 fraction.
 *
 * The returned renderer shows:
 *  - `nullLabel` when the value is `null`,
 *  - 'N/A' when the value is `undefined` (field absent from the row) instead
 *    of the string "NaN" that `(undefined * 100).toFixed(2)` would produce,
 *  - otherwise the value multiplied by 100 with two decimals.
 *
 * @param {string} nullLabel - text displayed for a `null` score.
 * @returns {function(*): string} renderer usable as a column `render`.
 */
const makeScoreRenderer = (nullLabel) => (value) => {
    if (value === null) {
        return nullLabel;
    }
    if (value === undefined) {
        return 'N/A';
    }
    return (value * 100).toFixed(2);
};

// BLEU/ROUGE-L show 'N/A' for null; the other metrics show 'Computing...'.
const renderAvailableScore = makeScoreRenderer('N/A');
const renderPendingScore = makeScoreRenderer('Computing...');

// Treats a missing/null score as 0 so the comparator never yields NaN.
const scoreOrZero = (value) => (value == null ? 0 : value);

// Metric columns for the table showing a single submission result.
const columnsSingle = [
    {title: 'BLEU4', dataIndex: 'bleu', key: 'bleu', render: renderAvailableScore},
    {title: 'ROUGEL', dataIndex: 'rougel', key: 'rougel', render: renderAvailableScore},
    {title: 'Bertscore', dataIndex: 'bertscore', key: 'bertscore', render: renderPendingScore},
    {title: 'F1-cheXbert', dataIndex: 'f1chexbert', key: 'f1chexbert', render: renderPendingScore},
    {title: 'F1-RadGraph', dataIndex: 'f1radgraph', key: 'f1radgraph', render: renderPendingScore},
];

// Leaderboard columns: a team column followed by the same metric columns,
// with F1-RadGraph sorted in descending order by default.
const columnsAll = [
    {
        title: 'Team',
        dataIndex: 'team',
        key: 'team',
    },
    ...columnsSingle.map((column) => {
        if (column.key !== 'f1radgraph') {
            return {...column};
        }
        return {
            ...column,
            defaultSortOrder: 'descend',
            sorter: {
                compare: (a, b) => scoreOrZero(a.f1radgraph) - scoreOrZero(b.f1radgraph),
                multiple: 1,
            },
        };
    }),
];

// Same columns as the leaderboard, except that the "team" column is swapped
// for a "mode" column; every other column object is reused as-is.
const columnsAllBaseline = columnsAll.map((column) =>
    column.dataIndex === 'team'
        ? {title: 'Mode', dataIndex: 'mode', key: 'mode'}
        : column
);

const Leaderboard = () => {

    const [waitForTeam, setWaitForTeam] = useState(false);
    const [waitForResult, setWaitForResult] = useState(false);
    const [hyps, setHyps] = useState(null);
    const [team, setTeam] = useState("");
    const [pin, setPin] = useState("");
    const [filename, setFilename] = useState("");
    const {
        codeStyle,
        errorMessage,
        setErrorMessage
    } = useContext(GlobalContext);

    // current_result
    const [data, setData] = useState(null);
    const [leaderboard, setLeaderboard] = useState({});

    const getLeaderboard = async (item) => {
        try {
            const response = await bionlp24Api.getScores(item);
            if ("error" in response)
                setErrorMessage(response["error"])
            return {[item]: response};
        } catch (error) {
            console.error(error);
            return {[item]: null};
        }
    };

    useEffect(() => {
        setErrorMessage("")
        Promise.all(Object.keys(DATASETS).map(getLeaderboard))
            .then(results => {
                // Combine all the results into a single object
                const newLeaderboard = results.reduce((acc, curr) => ({...acc, ...curr}), {});
                // Update the state with the new leaderboard
                setLeaderboard(newLeaderboard);
            });
    }, [])

    const resetState = () => {
        setData(null)
        setHyps(null)
        setTeam("")
        setPin("")
        setWaitForTeam(false)
        setWaitForResult(false)
        setFilename("")
    }
    const handleFileUpload = async ({fileList: newFileList}, dataset) => {
        resetState()
        setWaitForTeam(true)

        if (newFileList.length > 0) {
            let file = newFileList[0]
            let hyps = await new Promise((resolve) => {
                const reader = new FileReader();
                reader.readAsDataURL(file.originFileObj);
                reader.onload = () => resolve(reader.result);

            })
            hyps = atob(hyps.split(',')[1]).split("\n")

            // remove trailing blank line if present
            if (typeof hyps[hyps.length - 1] === 'string' && hyps[hyps.length - 1] === '') {
                hyps.pop();
            }

            if (hyps.length !== DATASETS[dataset]) {
                message.error(`Num of lines mismatch:  ${hyps.length} (your file) vs ${DATASETS[dataset]} (reference file) !`);
                resetState()
                return false
            }
            setHyps(hyps)
            setFilename(file["name"])
        }
    }
    const resetResults = async () => {
        resetState()
    }
    const submitResults = async (dataset) => {
        setWaitForResult(true)
        setWaitForTeam(false)

        if (!(team) || team === "") {
            message.error("No team provided")
            setWaitForResult(false)
            return
        }

        if (!(pin) || pin === "") {
            message.error("No pin provided")
            setWaitForResult(false)
            return
        }

        if (!Array.isArray(hyps) || hyps.length === 0) {
            message.error("No hyps provided")
            setWaitForResult(false)
            return
        }

        let ret = await bionlp24Api.computeScores(hyps, dataset, pin, team)

        if (!(typeof ret === 'object' && !Array.isArray(ret))) {
            resetState()
            openNotificationWithIcon("error", "Error happened in the scoring, please try again later")
            return
        }

        if ('error' in ret) {
            resetState()
            openNotificationWithIcon("error", ret["error"])
            return
        }
        ret["key"] = 1
        setData([ret])
        setHyps(null)
        setTeam("")
        setPin("")
        setWaitForTeam(false)
        setWaitForResult(false)
        setFilename("")
        openNotificationWithIcon("success", "Result successfully sent. It will soon be updated with the missing values.")

    }
    if (errorMessage) {
        return (
            <Alert
                message="An error has been caught"
                description={`Sorry. We encountered the following error: ${errorMessage}`}
                type="error"
                showIcon
            />
        );
    }

    return (
        <div style={{display: 'flex'}}>
            <div style={{width: 1000}} id={"docPageContent"}>
                <Space direction={"vertical"} size={20}>
                    <Title level={3} id="anchor-splits">1. Splits</Title>

                    <Table dataSource={[
                        {
                            key: '1',
                            split: 'validation',
                            findings: '8,839',
                            impression: '9,331',

                        },
                        {
                            key: '2',
                            split: 'test-public',
                            findings: "2,692",
                            impression: "2,967",
                        },
                        {
                            key: '3',
                            split: 'test-hidden',
                            findings: "1,063",
                            impression: "1,428",
                        }
                    ]
                    } columns={columnsSplits} pagination={false}/>

                    <Title level={3} id="anchor-how-to">2. How to and rules</Title>
                    <Alert
                        message="Rules"
                        description={<>
                            <p>
                                You can make one submission each 8 hours on the public test-set and each hour on the
                                hidden
                                test-set.
                                Submission on the the public test-set set are useful to validate your local results
                                match
                                the results online.</p>
                            <p>
                                You can only use one team name unless you plan on writing multiple papers with
                                substantially
                                different scientific contributions.
                            </p>
                        </>

                        }
                        type="warning"
                        style={{margin: 12}}
                        showIcon
                    />

                    To submit your model generations, please use Section 3 the following way:
                    <Paragraph>
                        <ol>
                            <li>
                                Select the dataset you want to make the submission for.
                            </li>
                            <li>
                                Upload a text file containing one impression or findings per line. This file should be
                                aligned with                                the provided datasets as such: <br/>

                                <SyntaxHighlighter
                                    customStyle={{textAlign: "left"}}
                                    language="bash"
                                    style={codeStyle}>
                                    {
                                        'import datasets' +
                                        'dataset = datasets.load_dataset("StanfordAIMI/interpret-cxr-test-hidden/")\n' +
                                        'findings_to_generate = [s for s in dataset["test"] if s["findings"]]'
                                    }
                                </SyntaxHighlighter>
                            </li>
                            <li>
                                Enter a team name and a pin code. If this is your first submission, your team name will
                                be
                                created. The pin code ensures that only you can submit for your team, <b>please remember
                                it.</b>
                            </li>
                            <Alert
                                message="Please Read"
                                description={<>The team name is the name that will prepend to the title of your paper at
                                    BioNLP. You can see example of the previous edition <Link
                                        href={"https://aclanthology.org/volumes/2023.bionlp-1/"}>here</Link> (look for "
                                    RadSum23").</>}
                                type="warning"
                                style={{margin: 12}}
                                showIcon
                            />
                            <li>
                                Click submit and wait for the success message stating that your submission has been
                                recorded. Because some metrics are computationally expensive, your submission will be
                                scored
                                asynchronously.
                            </li>
                        </ol>

                        Finally, each impression of your submission will be processed as such before scoring to match
                        the
                        ground-truth processing on our side:
                        <SyntaxHighlighter customStyle={{textAlign: "left"}} language="python" style={codeStyle}>
                            {code_process}
                        </SyntaxHighlighter>
                        And evaluated as such:
                        <SyntaxHighlighter customStyle={{textAlign: "left"}} language="python" style={codeStyle}>
                            {'import json\n' +
                                'import logging\n' +
                                'from vilmedic.blocks.scorers.scores import compute_scores\n' +
                                '\n' +
                                'refs = [\n' +
                                '    "The lungs are clear. The cardiomediastinal silhouette is within normal limits. No acute osseous abnormalities.",\n' +
                                '    "The lungs are clear.There is no pleural effusion or pneumothorax.The cardiomediastinal silhouette is normal."\n' +
                                ']\n' +
                                'hyps = [\n' +
                                '    "The lungs are clear. There is no pleural effusion or pneumothorax. The cardiomediastinal silhouette is normal.",\n' +
                                '    "The lungs are clear. The cardiomediastinal silhouette is within normal limits. No acute osseous abnormalities."\n' +
                                ']\n' +
                                'print("Computing metrics, this can take a while...")\n' +
                                'print(json.dumps(compute_scores(["ROUGEL", "bertscore", "radgraph", "BLEU", "chexbert"],\n' +
                                '                                refs=refs,\n' +
                                '                                hyps=hyps,\n' +
                                '                                split=None,\n' +
                                '                                seed=None,\n' +
                                '                                config=None,\n' +
                                '                                epoch=None,\n' +
                                '                                logger=logging.getLogger(__name__),\n' +
                                '                                dump=False),\n' +
                                '                 indent=4)\n' +
                                '      )\n'}
                        </SyntaxHighlighter>
                        In this challenge, the chosen radgraph metric is "radgraph_partial" from the <a
                        href="https://pypi.org/project/radgraph/">radgraph==0.1.2</a> and "chexbert-all_micro
                        avg_f1-score" for <a href={"https://pypi.org/project/f1chexbert/"}>f1chexbert</a>. The bertscore
                        is computed as detailed in this <a
                        href={"https://github.com/jbdel/vilmedic/blob/main/vilmedic/blocks/scorers/NLG/bertscore/bertscore.py"}>script</a>.


                    </Paragraph>

                    <Title level={3} id="anchor-submit">3. Submit</Title>

                    <LeaderboardDragger
                        handleFileUpload={handleFileUpload}
                        loading={waitForResult || waitForTeam}
                        hyps={hyps}
                        setTeam={setTeam}
                        setPin={setPin}
                        pin={pin}
                        filename={filename}
                        submitResults={submitResults}
                        resetResults={resetResults}
                        datasets={Object.keys(DATASETS)}
                    />

                    {waitForResult &&
                        <div style={{padding: 50}}>
                            <Spin tip="Computing scores, this could take a few minutes...">
                                <div className="content"/>
                            </Spin>
                        </div>
                    }
                    {data &&
                        <div>
                            <Divider orientation="left">Submission</Divider>
                            <Table columns={columnsSingle} dataSource={data}/>
                        </div>
                    }

                    <Title level={3} id="anchor-retrieve">4. Retrieve my submissions</Title>
                    <Retrieve/>

                    <Title level={3} id="anchor-leaderboards">5. Leaderboards</Title>


                    {[
                        ["findings-public-test-set", "(findings) public test-set"],
                        ["impression-public-test-set", "(impression) public test-set"],
                        ["findings-hidden-test-set", "(findings) hidden test-set"],
                        ["impression-hidden-test-set", "(impression) hidden test-set"],
                    ].map((dataset) =>
                        <>
                            <Title level={5}>{dataset[1]} ({DATASETS[dataset[0]]} samples)</Title>
                            <div>
                                <Divider orientation="left">Leaderboard</Divider>
                                <Table columns={columnsAll}
                                       dataSource={leaderboard[dataset[0]]}
                                       loading={(Object.keys(leaderboard).length === 0)}
                                />
                            </div>

                        </>
                    )}

                    <Title level={3} id="anchor-baseline">6. Baseline</Title>
                    <Paragraph>
                        The baseline has been trained with <img style={{"width": 14}} src={"/favicon/favicon-64x64.png"}
                                                                alt={"vilmedic-front"}/> ViLMedic
                        on top of the <img style={{"width": 14}} src={"/images/hf.png"} alt={"hf"}/> huggingface
                        library.
                        The baseline consist of a <a
                        href={"https://huggingface.co/docs/transformers/v4.39.1/en/model_doc/swinv2"}>swinv2</a> visual
                        encoder and two-layered bert as decoder.
                        The models are hosted on the model-zoo and can be evaluated as such:
                    </Paragraph>
                    <Collapse>
                        <Panel header="Evaluation script" key="1">
                            <SyntaxHighlighter
                                customStyle={{textAlign: "left"}}
                                language="python"
                                style={codeStyle}>
                                {
                                    'import json\n' +
                                    'import datasets\n' +
                                    'import tqdm\n' +
                                    'import logging\n' +
                                    'import torch\n' +
                                    'from PIL import Image\n' +
                                    'from torch.utils.data import DataLoader\n' +
                                    'from nltk.tokenize import wordpunct_tokenize\n' +
                                    'from vilmedic.blocks.scorers.scores import compute_scores\n' +
                                    'from transformers import BertTokenizer, ViTImageProcessor, VisionEncoderDecoderModel, GenerationConfig\n' +
                                    '\n' +
                                    'dataset = datasets.load_dataset("StanfordAIMI/interpret-cxr", split="validation")\n' +
                                    '\n' +
                                    'for mode in ["findings", "impression"]:\n' +
                                    '    # Model\n' +
                                    '    model = VisionEncoderDecoderModel.from_pretrained(f"IAMJB/interpret-cxr-{mode}-baseline").eval()\n' +
                                    '    tokenizer = BertTokenizer.from_pretrained(f"IAMJB/interpret-cxr-{mode}-baseline")\n' +
                                    '    image_processor = ViTImageProcessor.from_pretrained(f"IAMJB/interpret-cxr-{mode}-baseline")\n' +
                                    '\n' +
                                    '    # Dataset\n' +
                                    '    generation_args = {\n' +
                                    '        "bos_token_id": model.config.bos_token_id,\n' +
                                    '        "eos_token_id": model.config.eos_token_id,\n' +
                                    '        "pad_token_id": model.config.pad_token_id,\n' +
                                    '        "num_return_sequences": 1,\n' +
                                    '        "max_length": 128,\n' +
                                    '        "use_cache": True,\n' +
                                    '        "beam_width": 2,\n' +
                                    '    }\n' +
                                    '\n' +
                                    '    # Create DataLoader\n' +
                                    '    filtered_dataset = [sample for sample in dataset if sample[mode] != ""]\n' +
                                    '    data_loader = DataLoader(filtered_dataset, batch_size=128, collate_fn=lambda x: x)\n' +
                                    '\n' +
                                    '    # Inference\n' +
                                    '    refs = []\n' +
                                    '    hyps = []\n' +
                                    '    with torch.no_grad():\n' +
                                    '        for batch in tqdm.tqdm(data_loader, total=len(data_loader.dataset) // data_loader.batch_size):\n' +
                                    '            images = [Image.open(sample["images_path"][0]).convert(\'RGB\') for sample in batch]\n' +
                                    '            pixel_values = image_processor(images, return_tensors="pt").pixel_values\n' +
                                    '            # Generate predictions\n' +
                                    '            generated_ids = model.generate(\n' +
                                    '                pixel_values,\n' +
                                    '                generation_config=GenerationConfig(\n' +
                                    '                    **{**generation_args, "decoder_start_token_id": tokenizer.cls_token_id})\n' +
                                    '            )\n' +
                                    '            generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)\n' +
                                    '            refs.extend([\' \'.join(wordpunct_tokenize(s[mode].lower())) for s in batch])\n' +
                                    '            hyps.extend(generated_texts)\n' +
                                    '\n' +
                                    '    print("Computing metrics, this can take a while...")\n' +
                                    '    print(json.dumps(compute_scores(["ROUGEL", "bertscore", "radgraph", "BLEU", "chexbert"],\n' +
                                    '                                    refs=refs,\n' +
                                    '                                    hyps=hyps,\n' +
                                    '                                    split=None,\n' +
                                    '                                    seed=None,\n' +
                                    '                                    config=None,\n' +
                                    '                                    epoch=None,\n' +
                                    '                                    logger=logging.getLogger(__name__),\n' +
                                    '                                    dump=False),\n' +
                                    '                     indent=4)\n' +
                                    '          )\n' +
                                    '\n'
                                }
                            </SyntaxHighlighter>
                        </Panel>
                    </Collapse>

                    <Text> They report the following scores <b>on the validation set</b>:
                    </Text>

                    <Table columns={columnsAllBaseline} dataSource={dataSource_baseline}/>
                    The training command is the following:
                    <Collapse>
                        <Panel header="Training command" key="1">
                            <SyntaxHighlighter
                                customStyle={{textAlign: "left"}}
                                language="bash"
                                style={codeStyle}>
                                {
                                    'python bin/train.py config/RRG/baseline-mimic-HF.yml \\\n' +
                                    '        dataset.seq.processing=ifcc_clean_report \\\n' +
                                    '        dataset.seq.hf_dataset=StanfordAIMI/interpret-cxr \\\n' +
                                    '        dataset.seq.hf_field=findings \\\n' +
                                    '        dataset.seq.hf_filter=\'lambda e:e["findings"]\' \\\n' +
                                    '        dataset.seq.tokenizer_max_len=128 \\\n' +
                                    '        dataset.seq.file=null \\\n' +
                                    '        dataset.image.hf_dataset=StanfordAIMI/interpret-cxr \\\n' +
                                    '        dataset.image.hf_field=images \\\n' +
                                    '        dataset.image.hf_filter=\'lambda e:e["findings"]\' \\\n' +
                                    '        dataset.image.multi_image=1 \\\n' +
                                    '        dataset.image.image_path=/home/users/jbdel/scratch/ \\\n' +
                                    '        dataset.image.file=null \\\n' +
                                    '        model.proto=RRG_HF \\\n' +
                                    '        model.vision=microsoft/swinv2-tiny-patch4-window8-256 \\\n' +
                                    '        model.decoder.proto_config_args.num_hidden_layers=2 \\\n' +
                                    '        trainor.batch_size=16 \\\n' +
                                    '        trainor.grad_accu=8 \\\n' +
                                    '        trainor.optim_params.lr=0.0003 \\\n' +
                                    '        trainor.optimizer=RAdam \\\n' +
                                    '        trainor.early_stop_metric=bertscore \\\n' +
                                    '        trainor.early_stop=10 \\\n' +
                                    '        validator.batch_size=8 \\\n' +
                                    '        validator.beam_width=2 \\\n' +
                                    '        validator.metrics=[bertscore] \\\n' +
                                    '        validator.splits=[validation] \\\n' +
                                    '        ckpt_dir=ckpt \\\n' +
                                    '        name=interpret_nll_findings'
                                }
                            </SyntaxHighlighter>
                        </Panel>
                    </Collapse>


                </Space>
            </div>
            <div style={{marginLeft: '10px'}}>
                <Anchor>
                    <LinkAnchor href="#anchor-splits" title="Splits"/>
                    <LinkAnchor href="#anchor-how-to" title="How-to"/>
                    <LinkAnchor href="#anchor-submit" title="Submit"/>
                    <LinkAnchor href="#anchor-retrieve" title="Retrieve"/>
                    <LinkAnchor href="#anchor-leaderboards" title="Leaderboards"/>
                    <LinkAnchor href="#anchor-baseline" title="Baseline"/>
                </Anchor>
            </div>
        </div>
    )

}

export default Leaderboard

