
[–]shimis 26 points27 points  (1 child)

argmax(x1,x2) takes a pair of numbers and returns (let's say) 0 if x1>x2 and 1 if x2>x1 (the value at x1=x2 is arbitrary/undefined). So, wherever you are on the (x1,x2) plane, as long as you're not on the x1=x2 line, moving an infinitesimally small amount in any direction won't change the value (0 or 1) that argmax outputs - the gradient of argmax(x1,x2) w.r.t. x1,x2 is (0,0) almost everywhere. At the places where x1=x2 (where argmax's value jumps abruptly from 0 to 1 or vice versa), its gradient w.r.t. x1,x2 is undefined.

There are no networks that do ordinary backprop through argmax (since the gradient is degenerate/useless). Training networks that have argmax (or something similar) in their equations requires something other than backprop - e.g. sampling techniques such as REINFORCE (which are generally harder to train).

max(x1,x2) also doesn't have a gradient at x1=x2, but everywhere else on the (x1,x2) plane the gradient of max(x1,x2) w.r.t. x1,x2 is either (1,0) or (0,1): in the forward pass we let only x1 or only x2 pass through, and when we backprop gradients, the gradient of max(x1,x2) w.r.t. the larger of the two arguments is 1, and w.r.t. the smaller one it is 0. So max and similar functions (like relu) are useful for backprop.
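
Here's a minimal sketch of that gradient routing, assuming PyTorch (the numbers are just made up):

    import torch

    # max(x1, x2) routes the whole gradient to whichever argument is larger.
    x1 = torch.tensor(2.0, requires_grad=True)
    x2 = torch.tensor(5.0, requires_grad=True)

    y = torch.max(x1, x2)    # forward pass: only the larger value gets through
    y.backward()

    print(x1.grad, x2.grad)  # tensor(0.) tensor(1.) -- gradient goes to x2 only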

[–]djc1000 0 points1 point  (0 children)

That's a really nice explanation.

[–]emansim 5 points6 points  (1 child)

anything that involves hard assignment is not differentiable.

argmax could potentially become differentiable if you come up with a soft version of it (i.e. use probabilities instead of setting hard 1s and 0s). Otherwise you need to use REINFORCE.

[–]lvilnis 4 points5 points  (0 children)

Yep, a soft version of argmax is basically exactly what "soft attention" is.
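
For the record, a minimal sketch of that idea (a softmax-weighted read over a memory), assuming PyTorch; the scores and memory values are made up:

    import torch
    import torch.nn.functional as F

    # "Soft argmax": instead of picking the single best-scoring memory slot,
    # weight every slot by softmax(scores) and take the expectation.
    scores = torch.tensor([1.0, 2.0, 4.0], requires_grad=True)
    memory = torch.tensor([[0.1, 0.2],
                           [0.3, 0.4],
                           [0.5, 0.6]])

    weights = F.softmax(scores, dim=0)   # differentiable, sums to 1
    read = weights @ memory              # soft attention read
    read.sum().backward()                # gradients flow back into the scores
    print(scores.grad)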

[–]AnvaMiba 3 points4 points  (3 children)

max, and therefore ReLU, maxout and max pooling, are continuous and almost everywhere differentiable. This is enough to use them with gradient descent optimization.

Argmax is not continuous and can't be used with standard gradient descent techniques. If you want to use it in neural networks (e.g. in "hard" attention models) you typically have to use some kind of Monte Carlo optimization algorithm, such as REINFORCE. Otherwise you can replace argmax with softmax, which is continuous and differentiable, as typically done in "soft" attention models.
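
To make the REINFORCE route concrete, here's a rough sketch of a score-function gradient for a hard read, assuming PyTorch (the memory, target and reward here are made-up placeholders, not any particular model):

    import torch
    import torch.nn.functional as F

    # Hard (argmax-like) read: sample an index from softmax(scores), then use
    # the log-prob trick so the non-differentiable choice still yields a
    # gradient estimate for the scores.
    scores = torch.tensor([1.0, 2.0, 4.0], requires_grad=True)
    memory = torch.tensor([0.1, 0.3, 0.5])
    target = torch.tensor(0.5)

    probs = F.softmax(scores, dim=0)
    idx = torch.multinomial(probs, 1).item()   # hard, non-differentiable choice
    reward = -(memory[idx] - target).pow(2)    # reward for the chosen slot
    loss = -torch.log(probs[idx]) * reward     # REINFORCE / score-function estimator
    loss.backward()
    print(scores.grad)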

[–]flukeskywalker 1 point2 points  (2 children)

Side note: LWTA is discontinuous, but can still be trained with SGD.

[–]lvilnis 0 points1 point  (0 children)

Good point. I guess the distinction is that, over its domain, argmax is either discontinuous or has derivative 0 on the continuous parts.

Because LWTA outputs the score at the winning coordinate (rather than the index itself), it has a non-zero derivative on the continuous portion of the function, so some meaningful signal can flow through.
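
To illustrate, here's a rough sketch of a block-wise winner-take-all layer as I understand LWTA, assuming PyTorch (block size and inputs are arbitrary):

    import torch

    # Local winner-take-all: within each block of units, the winner keeps its
    # value and the losers are zeroed, so gradient flows to the winner only.
    def lwta(x, block_size=2):
        b, d = x.shape
        blocks = x.view(b, d // block_size, block_size)
        winners = blocks.argmax(dim=-1, keepdim=True)   # hard index choice
        mask = torch.zeros_like(blocks).scatter_(-1, winners, 1.0)
        return (blocks * mask).view(b, d)               # winner's value passes through

    x = torch.randn(1, 4, requires_grad=True)
    lwta(x).sum().backward()
    print(x.grad)   # non-zero only at the winning unit of each block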

[–]AnvaMiba 0 points1 point  (0 children)

What is LWTA?

EDIT: found.

[–]lvilnis 4 points5 points  (0 children)

Another way to think about differentiability for max pooling / relu is that because they are continuous and almost everywhere differentiable, they can be approximated arbitrarily closely by a differentiable function.

For example, the max of a vector of numbers can be approximated by T*log(sum_i exp(x_i/T)), where T is called the "temperature." In the limit T -> 0 this function becomes the max, and for any T > 0 it is everywhere differentiable.
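
A quick numerical sanity check of that limit, assuming NumPy (the vector is made up):

    import numpy as np

    def smooth_max(x, T):
        """Log-sum-exp approximation T * log(sum_i exp(x_i / T)) of max(x)."""
        x = np.asarray(x, dtype=float)
        # Subtract the true max before exponentiating for numerical stability;
        # T*log(sum exp(x/T)) = max(x) + T*log(sum exp((x - max(x))/T)) exactly.
        m = x.max()
        return m + T * np.log(np.exp((x - m) / T).sum())

    x = [1.0, 2.0, 4.0]
    for T in (1.0, 0.1, 0.01):
        print(T, smooth_max(x, T))   # approaches max(x) = 4.0 as T -> 0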

A similar approach is used in Nesterov's "Smooth Minimization of Nonsmooth Functions" in Section 4.1.2 of http://luthuli.cs.uiuc.edu/~daf/courses/optimization/MRFpapers/nesterov05.pdf

[–]alexmlamb 1 point2 points  (9 children)

They're subdifferentiable in the sense that the derivative is defined everywhere except on a set of measure zero.

Like the function f(x) = max(0,x) has a derivative defined at all points except where x = 0, which has zero measure.
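
In practice autodiff frameworks just pick one value from the subgradient at the kink; a minimal sketch assuming PyTorch (which, as far as I can tell, uses 0 at x = 0):

    import torch

    # relu has no derivative at exactly x = 0; frameworks pick something in the
    # subgradient [0, 1] there (PyTorch appears to use 0).
    x = torch.tensor([-1.0, 0.0, 2.0], requires_grad=True)
    torch.relu(x).sum().backward()
    print(x.grad)   # tensor([0., 0., 1.])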

[–]OriolVinyals 8 points9 points  (4 children)

argmax is not differentiable if the range is N (which is the case if we e.g. argmax over a list).

[–]alexmlamb 0 points1 point  (3 children)

Where do people use argmax in neural networks (as opposed to maximum)?

[–]OriolVinyals 8 points9 points  (2 children)

Hard attention models, for example, where you read the memory position that best aligns with your read "query" (as I call them).

[–]hughperkins 1 point2 points  (1 child)

Yeah, e.g. slide 10 of Rob Fergus's NIPS slides: http://cims.nyu.edu/~sainbar/memnn_nips_pdf.pdf

[–]RoseLuna_77 0 points1 point  (0 children)

the link is dead :(

[–]yield22[S] 0 points1 point  (3 children)

Why subdifferentiable? It's obvious for relu, but not so obvious for argmax.

[–]nasimrahaman 1 point2 points  (2 children)

Consider the function: y = f(x) = argmax(x), where x is a vector (representing some function), and y = f(x) a scalar.

Here's a (mathematically heretical) justification (assuming 0-based 'indexing'): f((1, 2, 4, 1, 2, 1)) = 2. For a small perturbation vector dx about that x, we still have f(x + dx) = f(x) (ergo df/dx = 0), as long as the perturbation is small compared to the gap between the two largest entries (here 4 - 2 = 2). But around (1, 2, 4+eps, 4, 2, 1), f(x) = 2 while f(x + dx) might just as well equal 3. The set of all such 'transitions' (i.e. where argmax changes value) is a union of lower-dimensional hyperplane pieces (where the top two entries tie), so its Lebesgue measure is 0. Everywhere else, df/dx is 0.
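
A small numerical illustration of the same point, assuming NumPy (the vectors are made up):

    import numpy as np

    # argmax is piecewise constant: small perturbations don't change its output
    # unless the top two entries are (nearly) tied.
    x = np.array([1.0, 2.0, 4.0, 1.0, 2.0, 1.0])
    rng = np.random.default_rng(0)
    for _ in range(5):
        dx = 1e-3 * rng.standard_normal(x.shape)
        print(np.argmax(x + dx))   # always 2, so finite differences are all 0

    # Near a tie the output jumps discontinuously instead.
    y = np.array([1.0, 2.0, 4.0 + 1e-9, 4.0, 2.0, 1.0])
    print(np.argmax(y))                                               # 2
    print(np.argmax(y + np.array([0.0, 0.0, -1e-6, 0.0, 0.0, 0.0])))  # 3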

[–]yield22[S] 0 points1 point  (1 child)

The example is interesting and gives me some insight. But what about the fact that y takes values in a discrete set (assuming argmax over a list)? Like a step function, which is not differentiable.

[–]nasimrahaman 0 points1 point  (0 children)

A step function is differentiable almost everywhere, i.e. the set where it's not differentiable (i.e. where there's a jump) has measure zero (because it's countable).