Forward pass execution with custom forward derivative


I noticed that when I specify a custom forward derivative for a function, only the derivative is evaluated, not the function itself. Maybe I’m doing something wrong, so I wanted to check here before opening an issue.

A minimal example is the following:

#include <stdio.h>

int enzyme_const, enzyme_dup;
extern void __enzyme_fwddiff(void*, ...);

// Function to differentiate: y = x^2
void square(double *y, double *x) {
    *y = *x * *x;
}

// Custom forward derivative: writes dy only, never y
void grad_square(double *y, double *dy, double *x, double *dx) {
    *dy = 3 * *x; // Purposefully wrong
}

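// Register grad_square as the custom forward derivative of square
void* __enzyme_register_derivative_square[] = {
    (void*) square,
    (void*) grad_square,
};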

int main() {
    double x = 5.0;
    double dx = 1.0;
    double y = 0.0;
    double dy = 0.0;
    __enzyme_fwddiff((void*) square, enzyme_dup, &y, &dy, enzyme_dup, &x, &dx);
    printf("f(%f) = %f, df(%f) = %f\n", x, y, x, dy);
}

which prints out

f(5.000000) = 0.000000, df(5.000000) = 15.000000

Without the custom derivative, the square function itself is evaluated as well:

f(5.000000) = 25.000000, df(5.000000) = 10.000000

Is this expected behavior or a bug?


Oh, actually I realized my mistake after typing this up: the grad_square custom derivative also has to perform the forward pass itself. So something like

void grad_square(double *y, double *dy, double *x, double *dx) {
    square(y, x); // Include forward pass
    *dy = 3 * *x; // Purposefully wrong
}

is correct. My bad, I guess I was confused by the custom registration, but it actually makes more sense that you’d specify both the function and its derivative in one custom call.
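For anyone landing here later, the convention can be sketched in plain C without Enzyme at all. The two functions below, `fwd_tangent_only` and `fwd_with_primal`, are names made up for this illustration; each stands in for a custom derivative that gets called in place of `square`. The only difference is whether the primal output `*y` is written. Unlike the purposefully wrong example above, this sketch uses the actual derivative 2x and scales it by the seed `*dx`.

```c
/* Original function: y = x^2. */
void square(double *y, double *x) {
    *y = *x * *x;
}

/* Hypothetical custom derivative that only writes the tangent.
   If this replaces the call to square, nothing ever writes *y. */
void fwd_tangent_only(double *y, double *dy, double *x, double *dx) {
    (void)y;                 /* primal output is left untouched */
    *dy = 2.0 * *x * *dx;    /* d(x^2) = 2x * dx */
}

/* Hypothetical custom derivative that also performs the forward pass. */
void fwd_with_primal(double *y, double *dy, double *x, double *dx) {
    square(y, x);            /* write the primal result */
    *dy = 2.0 * *x * *dx;    /* write the tangent */
}
```

Calling both with x = 5, dx = 1 and y initialised to 0 leaves y at 0 for the first version and sets it to 25 for the second; dy comes out as 10 in both cases.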

I don’t need help on this but I’ll leave it up in case it’s helpful to anyone! Doesn’t hurt to see more simple examples I guess.


Yeah, part of the reasoning is this: suppose the original function contains an instruction that can’t legally be executed twice (say, a read from stdin). Having both the original code and the forward derivative in one place lets them share that data instead of reading it twice.

That doesn’t mean we can’t have a custom calling convention that keeps the original code separate, but the default is to combine them.
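The point about instructions that can't run twice can be illustrated with a small plain-C sketch. Everything here is an assumption for illustration only: `read_sample` is a hypothetical stand-in for a non-repeatable read such as stdin (it just walks a fixed array), and `scale_by_input` / `fwd_scale_by_input` are made-up names, not anything Enzyme generates.

```c
/* Stand-in for a non-repeatable read: each call consumes one value. */
static const double samples[] = {2.0, 7.0, 11.0};
static int cursor = 0;

double read_sample(void) {
    /* Return 0.0 once the hypothetical input is exhausted. */
    return cursor < 3 ? samples[cursor++] : 0.0;
}

/* Original function: y = s * x, where s is consumed from the input. */
void scale_by_input(double *y, double *x) {
    double s = read_sample();
    *y = s * *x;
}

/* Combined primal + forward derivative: s is read once and shared. */
void fwd_scale_by_input(double *y, double *dy, double *x, double *dx) {
    double s = read_sample();   /* the single read */
    *y  = s * *x;               /* primal result */
    *dy = s * *dx;              /* tangent: d(s * x) = s * dx, s held constant */
}
```

With x = 3 and dx = 1, the first call to `fwd_scale_by_input` gives y = 6 and dy = 2 and consumes exactly one sample. A split design that ran `scale_by_input` and then a separate derivative routine with its own read would consume two samples, and the tangent would be computed from 7.0 rather than from the 2.0 the primal actually used.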
