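I think you can do this one exactly using neural networks. With the bias terms and an x² activation, the first layer can compute the squared distance from the input to every point of the dataset, with the dataset points stored in the bias vectors. (In d dimensions this takes d neurons per stored point, one per coordinate, followed by a fixed layer that sums each group of d.) This gives a (gigantic) vector, or a matrix for a batch of inputs, holding the distance from the current point to each point in the dataset. Squared distances are fine here, since they rank the neighbours in the same order as the plain distances.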
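A rough NumPy sketch of that first layer, just to show the weights can be written down by hand. The function names are made up for illustration; the hidden neuron for point i and coordinate j has weight 1 on input j and bias -data[i, j], so its x² activation outputs (x[j] - data[i, j])², and a fixed 0/1 layer adds up the d coordinates of each point.

```python
import numpy as np

def square(z):
    # The x^2 activation.
    return z ** 2

def distance_layer(x, data):
    """Squared Euclidean distance from x (shape (d,)) to each row of data (shape (N, d))."""
    n, d = data.shape
    # Hidden layer with N*d neurons: neuron i*d + j reads only input coordinate j.
    W1 = np.tile(np.eye(d), (n, 1))            # shape (N*d, d)
    b1 = -data.reshape(-1)                     # the biases store the dataset points
    hidden = square(W1 @ x + b1)               # hidden[i*d + j] = (x[j] - data[i, j])**2
    # Fixed summing layer: row i has ones over the d neurons belonging to point i.
    W2 = np.kron(np.eye(n), np.ones((1, d)))   # shape (N, N*d)
    return W2 @ hidden                         # shape (N,)
```

For example, with data = [[0, 0], [3, 4], [1, 1]] and x = [0, 0] this returns the squared distances 0, 25 and 2. Nothing is learned here: all weights come straight from the dataset, which is also why the layer grows linearly with the number of stored points.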
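Then you could apply a soft-min activation (a softmax over the negated distances) to extract the nearest neighbour. With a finite temperature this is only a smooth approximation; as the temperature goes to zero it becomes a hard argmin, which is what makes the "exact" claim work.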
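A small sketch of the soft-min step, assuming one-hot labels and a hand-picked temperature parameter (both are my additions, not part of the idea above). The distances are computed directly with NumPy here to keep it standalone, standing in for the first layer's output.

```python
import numpy as np

def softmin(dists, temperature):
    # softmin(d) = softmax(-d / T); shifting by the minimum avoids overflow.
    z = -(dists - dists.min()) / temperature
    w = np.exp(z)
    return w / w.sum()

def soft_nearest_label(x, data, labels, temperature=1e-3):
    # Squared distances to every stored point (what the first layer would output).
    dists = ((data - x) ** 2).sum(axis=1)
    weights = softmin(dists, temperature)
    # Weighted vote over one-hot labels (shape (N, C)); with a small temperature
    # nearly all weight sits on the closest point, so this approaches 1-NN.
    return weights @ labels
```

The temperature matters: if it is large compared with the gaps between distances, the weights spread over several points and the "vote" can come out for a label that is not the nearest neighbour's.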
Sounds pretty stupid and inefficient, but I think one can mimic the algorithm exactly.