#!/usr/bin/env python3
"""reversal.py

This is an Inspect AI[1] eval that checks for the "reversal curse"[2] in
language models. The reversal curse is the inability to emit or respond with
"B is A" after learning or internalizing the fact "A is B".

You can run this by invoking the inspect CLI:

> inspect eval reversal.py --model $MODEL

substituting your preferred $MODEL from the list of supported models[3].

1. https://inspect.aisi.org.uk/
2. https://arxiv.org/abs/2309.12288
3. https://inspect.aisi.org.uk/models.html
"""


from inspect_ai import Task, task
from inspect_ai.dataset import Sample, hf_dataset
from inspect_ai.solver import system_message, generate, TaskState, Generate, solver
from inspect_ai.scorer import exact
from inspect_ai.model import ChatMessageUser, ChatMessageAssistant


def record_to_parent_identification_sample(record) -> Sample:
    """Turn a parent/child record into a child -> parent question.

    Reads the ``child``, ``parent_type`` and ``parent`` keys. Example record
    from the dataset:

    {
        'child': 'Anthony Michael Hall',
        'parent': 'Mercedes Hall',
        'parent_type': 'mother',
        'child_prediction': 'Anthony Michael Hall',
        'can_reverse': True
    }

    The prompt is "Who is <child>'s <parent_type>?" and the target is the
    parent's name.
    """
    question = "Who is {}'s {}?".format(record["child"], record["parent_type"])
    return Sample(input=question, target=record["parent"])


def record_to_child_identification_sample(record) -> Sample:
    """Turn a parent/child record into a parent -> child ("reversed") question.

    Only the ``child`` and ``parent`` keys are read. Example record from the
    dataset:

    {
        'child': 'Anthony Michael Hall',
        'parent': 'Mercedes Hall',
        'parent_type': 'mother',
        'child_prediction': 'Anthony Michael Hall',
        'can_reverse': True
    }

    The prompt is "Who is <parent>'s child?" and the target is the child's
    name.
    """
    child = record["child"]
    parent = record["parent"]
    # ``parent_type`` is intentionally not read: the reversed question asks
    # for a "child" regardless of whether the parent is a mother or father.
    # NOTE(review): if a parent appears in several records with different
    # children, the target here is only this record's child.
    prompt = f"Who is {parent}'s child?"
    return Sample(input=prompt, target=child)


# The exact refusal text the model is told to produce when it cannot name the
# person; the few-shot seed records also use it as an answer.
UNKNOWN_STR = "I don't know."
SYSTEM_PROMPT = (
    "You are a helpful and terse assistant. "
    "You have knowledge of a wide range of people and can name people "
    "that the user asks for. "
    f'If the answer is unknown or not applicable, answer with "{UNKNOWN_STR}"'
)


@solver
def few_shot_pair_solver(seed_records, record_to_sample):
    """Insert few-shot question/answer turns just before the final message.

    Args:
        seed_records: Records to turn into few-shot examples, in the order
            they should appear in the conversation.
        record_to_sample: Function converting one record into a ``Sample``;
            its ``input`` becomes a user turn and its ``target`` the
            assistant reply that follows it.

    Returns:
        A solver that splices the few-shot turns into ``state.messages``
        immediately before its last message and returns the state.
    """
    # Convert the seed records once, when the solver is built, instead of
    # re-running ``record_to_sample`` for every evaluated sample. Building a
    # list also lets a one-shot iterable (e.g. a generator) serve all samples.
    shots = [record_to_sample(record) for record in seed_records]

    async def solve(state: TaskState, generate: Generate):
        # Create fresh message objects per sample so samples never share
        # ChatMessage instances.
        few_shot_messages = []
        for shot in shots:
            few_shot_messages.append(ChatMessageUser(content=shot.input))
            few_shot_messages.append(ChatMessageAssistant(content=shot.target))

        # Splice all turns in, in order, before the last message (presumably
        # the sample's own question). Unlike repeated ``insert(-1, ...)``, slice
        # assignment keeps the user/assistant order even if the list is empty.
        state.messages[-1:-1] = few_shot_messages
        return state

    return solve


# Few-shot examples, always phrased child -> parent. The last entry shows the
# model how to answer (with UNKNOWN_STR) when the parent is not known.
seed_records = [
    {"child": child, "parent": parent, "parent_type": parent_type}
    for child, parent, parent_type in (
        ("Malia Obama", "Barack Obama", "father"),
        ("Elon Musk", "Maye Musk", "mother"),
        ("Kathy Pratt", UNKNOWN_STR, "mother"),
    )
]


def _load_parent_child_pairs(sample_fields):
    """Load the celebrity parent/child CSV, mapping each row with *sample_fields*.

    Both evals read the same file through ``hf_dataset``; only the record ->
    Sample conversion (the direction of the question) differs.
    """
    return hf_dataset(
        "lberglund/reversal_curse",
        split="train",
        data_files="celebrity_relations/parent_child_pairs.csv",
        sample_fields=sample_fields,
    )


parent_identification_dataset = _load_parent_child_pairs(
    record_to_parent_identification_sample
)
child_identification_dataset = _load_parent_child_pairs(
    record_to_child_identification_sample
)


@task
def parent_identification():
    """Forward direction: given a celebrity, ask for their parent.

    Few-shot examples use the same child -> parent phrasing, and the answer
    is scored with ``exact()`` against the parent's name.
    """
    pipeline = [
        system_message(SYSTEM_PROMPT),
        few_shot_pair_solver(seed_records, record_to_parent_identification_sample),
        generate(),
    ]
    return Task(
        dataset=parent_identification_dataset,
        solver=pipeline,
        scorer=exact(),
    )


@task
def child_identification():
    """Reversed direction: given a parent, ask for their celebrity child.

    The answer is scored with ``exact()`` against the child's name.
    """
    # The original paper constructs few-shots in the child->parent direction,
    # so the examples stay in that direction even though the question is
    # reversed.
    few_shots = few_shot_pair_solver(
        seed_records, record_to_parent_identification_sample
    )
    pipeline = [system_message(SYSTEM_PROMPT), few_shots, generate()]
    return Task(
        dataset=child_identification_dataset,
        solver=pipeline,
        scorer=exact(),
    )
