001/*-
002 *******************************************************************************
003 * Copyright (c) 2011, 2016 Diamond Light Source Ltd.
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 *
009 * Contributors:
010 *    Peter Chang - initial API and implementation and/or initial documentation
011 *******************************************************************************/
012
013package org.eclipse.january.dataset;
014
015import java.util.ArrayList;
016import java.util.Collections;
017import java.util.Comparator;
018
019import org.apache.commons.lang.ArrayUtils;
020
021class InterpolatedPoint {
022
023        Dataset realPoint;
024        Dataset coordPoint;
025
026        public InterpolatedPoint(Dataset realPoint, Dataset coordPoint) {
027                this.realPoint = realPoint;
028                this.coordPoint = coordPoint;
029        }
030
031        public Dataset getRealPoint() {
032                return realPoint;
033        }
034
035        public Dataset getCoordPoint() {
036                return coordPoint;
037        }
038        
039        @Override
040        public String toString() {
041                String realString = "[ " + realPoint.getDouble(0);
042                for(int i = 1; i < realPoint.getShapeRef()[0]; i++) {
043                        realString += " , " + realPoint.getDouble(i);
044                }
045                realString += " ]";
046                
047                String coordString = "[ " + coordPoint.getDouble(0);
048                for(int i = 1; i < coordPoint.getShapeRef()[0]; i++) {
049                        coordString += " , " + coordPoint.getDouble(i) ;
050                }
051                coordString += " ]";
052                
053                return realString + " : " + coordString;
054        }
055
056}
057
058public class InterpolatorUtils {
059
060        public static Dataset regridOld(Dataset data, Dataset x, Dataset y,
061                        Dataset gridX, Dataset gridY) throws Exception {
062                
063                DoubleDataset result = DatasetFactory.zeros(DoubleDataset.class, gridX.getShapeRef()[0], gridY.getShapeRef()[0]);
064                
065                IndexIterator itx = gridX.getIterator();
066                
067                // need a list of lists to store points
068                ArrayList<ArrayList<InterpolatedPoint>> pointList = new ArrayList<ArrayList<InterpolatedPoint>>();
069                
070                while(itx.hasNext()){
071                        // Add a list to contain all the points which we find
072                        pointList.add(new ArrayList<InterpolatedPoint>());
073                        
074                        int xindex = itx.index;
075                        double xPos = gridX.getDouble(xindex);
076                        
077                        IndexIterator ity = gridY.getIterator();
078                        while(ity.hasNext()){
079                                int yindex = ity.index;
080                                System.out.println("Testing : "+xindex+","+yindex);
081                                double yPos = gridX.getDouble(yindex);
082                                result.set(getInterpolated(data, x, y, xPos, yPos), yindex, xindex);
083                                
084                        }
085                }
086                return result;
087        }
088        
089        
090        
091        
092        public static Dataset selectDatasetRegion(Dataset dataset, int x, int y, int xSize, int ySize) {
093                int startX = x - xSize;
094                int startY = y - ySize;
095                int endX = x + xSize + 1;
096                int endY = y + ySize +1;
097                
098                int shapeX = dataset.getShapeRef()[0];
099                int shapeY = dataset.getShapeRef()[1];
100                
101                // Do edge checking
102                if (startX < 0) {
103                        startX = 0;
104                        endX = 3;
105                } 
106                
107                if (endX > shapeX) {
108                        endX = shapeX;
109                        startX = endX-3;
110                }
111                
112                if (startY < 0) {
113                        startY = 0;
114                        endY = 3;
115                }
116                
117                if (endY > shapeY) {
118                        endY = shapeY;
119                        startY = endY-3;
120                }
121                
122                int[] start = new int[] { startX, startY };
123                int[] stop = new int[] { endX, endY };
124                
125                
126                return dataset.getSlice(start, stop, null);
127        }
128        
129        private static double getInterpolated(Dataset val, Dataset x, Dataset y, double xPos,
130                        double yPos) throws Exception {
131                
132                // initial guess
133                Dataset xPosDS = x.getSlice(new int[] {0,0}, new int[] {x.getShapeRef()[0],1}, null).isubtract(xPos);
134                int xPosMin = xPosDS.minPos()[0];
135                Dataset yPosDS = y.getSlice(new int[] {xPosMin,0}, new int[] {xPosMin+1,y.getShapeRef()[1]}, null).isubtract(yPos);
136                int yPosMin = yPosDS.minPos()[0];
137                
138                
139                // now search around there 5x5
140                
141                Dataset xClipped = selectDatasetRegion(x,xPosMin,yPosMin,2,2);
142                Dataset yClipped = selectDatasetRegion(y,xPosMin,yPosMin,2,2);
143                
144                // first find the point in the arrays nearest to the point
145                Dataset xSquare = Maths.subtract(xClipped, xPos).ipower(2);
146                Dataset ySquare = Maths.subtract(yClipped, yPos).ipower(2);
147
148                Dataset total = Maths.add(xSquare, ySquare);
149
150                int[] pos = total.minPos();
151
152                // now pull out the region around that point, as a 3x3 grid     
153                Dataset xReduced = selectDatasetRegion(x, pos[0], pos[1], 1, 1);
154                Dataset yReduced = selectDatasetRegion(y, pos[0], pos[1], 1, 1);
155                Dataset valReduced = selectDatasetRegion(val, pos[0], pos[1], 1, 1);
156
157                return getInterpolatedResultFromNinePoints(valReduced, xReduced, yReduced, xPos, yPos);
158        }
159
160        private static double getInterpolatedResultFromNinePoints(Dataset val, Dataset x, Dataset y,
161                        double xPos, double yPos) throws Exception {
162                
163                // First build the nine points
164                InterpolatedPoint p00 = makePoint(x, y, 0, 0);
165                InterpolatedPoint p01 = makePoint(x, y, 0, 1);
166                InterpolatedPoint p02 = makePoint(x, y, 0, 2);
167                InterpolatedPoint p10 = makePoint(x, y, 1, 0);
168                InterpolatedPoint p11 = makePoint(x, y, 1, 1);
169                InterpolatedPoint p12 = makePoint(x, y, 1, 2);
170                InterpolatedPoint p20 = makePoint(x, y, 2, 0);
171                InterpolatedPoint p21 = makePoint(x, y, 2, 1);
172                InterpolatedPoint p22 = makePoint(x, y, 2, 2);
173
174                // now try every connection and find points that intersect with the interpolated value
175                ArrayList<InterpolatedPoint> points = new ArrayList<InterpolatedPoint>();
176
177                InterpolatedPoint A = get1DInterpolatedPoint(p00, p10, 0, xPos);
178                InterpolatedPoint B = get1DInterpolatedPoint(p10, p20, 0, xPos);
179                InterpolatedPoint C = get1DInterpolatedPoint(p00, p01, 0, xPos);
180                InterpolatedPoint D = get1DInterpolatedPoint(p10, p11, 0, xPos);
181                InterpolatedPoint E = get1DInterpolatedPoint(p20, p21, 0, xPos);
182                InterpolatedPoint F = get1DInterpolatedPoint(p01, p11, 0, xPos);
183                InterpolatedPoint G = get1DInterpolatedPoint(p11, p21, 0, xPos);
184                InterpolatedPoint H = get1DInterpolatedPoint(p01, p02, 0, xPos);
185                InterpolatedPoint I = get1DInterpolatedPoint(p11, p12, 0, xPos);
186                InterpolatedPoint J = get1DInterpolatedPoint(p21, p22, 0, xPos);
187                InterpolatedPoint K = get1DInterpolatedPoint(p02, p12, 0, xPos);
188                InterpolatedPoint L = get1DInterpolatedPoint(p12, p22, 0, xPos);
189
190                // Now add any to the list which are not null
191                if (A != null)
192                        points.add(A);
193                if (B != null)
194                        points.add(B);
195                if (C != null)
196                        points.add(C);
197                if (D != null)
198                        points.add(D);
199                if (E != null)
200                        points.add(E);
201                if (F != null)
202                        points.add(F);
203                if (G != null)
204                        points.add(G);
205                if (H != null)
206                        points.add(H);
207                if (I != null)
208                        points.add(I);
209                if (J != null)
210                        points.add(J);
211                if (K != null)
212                        points.add(K);
213                if (L != null)
214                        points.add(L);
215
216                // if no intercepts, then retun NaN;
217                if (points.size() == 0) return Double.NaN;
218                
219                InterpolatedPoint bestPoint = null;
220
221                // sort the points by y
222                Collections.sort(points, new Comparator<InterpolatedPoint>() {
223
224                        @Override
225                        public int compare(InterpolatedPoint o1, InterpolatedPoint o2) {
226                                return (int) Math.signum(o1.realPoint.getDouble(1) - o2.realPoint.getDouble(1));
227                        }
228                });
229                
230                
231                // now we have all the points which fit the x criteria, Find the points which fit the y
232                for (int a = 1; a < points.size(); a++) {
233                        InterpolatedPoint testPoint = get1DInterpolatedPoint(points.get(a - 1), points.get(a), 1, yPos);
234                        if (testPoint != null) {
235                                bestPoint = testPoint;
236                                break;
237                        }
238                }
239
240                if (bestPoint == null) {
241                        return Double.NaN;
242                }
243
244                // now we have the best point, we can calculate the weights, and positions
245                int xs = (int) Math.floor(bestPoint.getCoordPoint().getDouble(0));
246                int ys = (int) Math.floor(bestPoint.getCoordPoint().getDouble(1));
247                
248                double xoff = bestPoint.getCoordPoint().getDouble(0) - xs;
249                double yoff = bestPoint.getCoordPoint().getDouble(1) - ys;
250
251                // check corner cases
252                if (xs == 2) {
253                        xs = 1;
254                        xoff = 1.0;
255                }
256                
257                if (ys == 2) {
258                        ys = 1;
259                        yoff = 1.0;
260                }
261                
262                double w00 = (1 - xoff) * (1 - yoff);
263                double w10 = (xoff) * (1 - yoff);
264                double w01 = (1 - xoff) * (yoff);
265                double w11 = (xoff) * (yoff);
266                
267                // now using the weights, we can get the final interpolated value
268                double result = val.getDouble(xs, ys) * w00;
269                result += val.getDouble(xs + 1, ys) * w10;
270                result += val.getDouble(xs, ys + 1) * w01;
271                result += val.getDouble(xs + 1, ys + 1) * w11;
272                
273                return result;
274        }
275
276        private static InterpolatedPoint makePoint(Dataset x, Dataset y, int i, int j) {
277                Dataset realPoint = DatasetFactory.createFromObject(new double[] { x.getDouble(i, j), y.getDouble(i, j) });
278                Dataset coordPoint = DatasetFactory.createFromObject(new double[] { i, j });
279                return new InterpolatedPoint(realPoint, coordPoint);
280        }
281
282        /**
283         * Gets an interpolated position when only dealing with 1 dimension for the interpolation.
284         * 
285         * @param p1
286         *            Point 1
287         * @param p2
288         *            Point 2
289         * @param interpolationDimension
290         *            The dimension in which the interpolation should be carried out
291         * @param interpolatedValue
292         *            The value at which the interpolated point should be at in the chosen dimension
293         * @return the new interpolated point.
294         * @throws IllegalArgumentException
295         */
296        private static InterpolatedPoint get1DInterpolatedPoint(InterpolatedPoint p1, InterpolatedPoint p2,
297                        int interpolationDimension, double interpolatedValue) throws IllegalArgumentException {
298                
299                checkPoints(p1, p2);
300
301                if (interpolationDimension >= p1.getRealPoint().getShapeRef()[0]) {
302                        throw new IllegalArgumentException("Dimention is too large for these datasets");
303                }
304
305                double p1_n = p1.getRealPoint().getDouble(interpolationDimension);
306                double p2_n = p2.getRealPoint().getDouble(interpolationDimension);
307                double max = Math.max(p1_n, p2_n);
308                double min = Math.min(p1_n, p2_n);
309                
310                if (interpolatedValue < min || interpolatedValue > max || min==max) {
311                        return null;
312                }
313                
314                double proportion = (interpolatedValue - min) / (max - min);
315                
316                return getInterpolatedPoint(p1, p2, proportion);
317        }
318
319        /**
320         * Gets an interpolated point between 2 points given a certain proportion
321         * 
322         * @param p1
323         *            the initial point
324         * @param p2
325         *            the final point
326         * @param proportion
327         *            how far the new point is along the path between P1(0.0) and P2(1.0)
328         * @return a new point which is the interpolated point
329         */
330        private static InterpolatedPoint getInterpolatedPoint(InterpolatedPoint p1, InterpolatedPoint p2, double proportion) {
331
332                checkPoints(p1, p2);
333
334                if (proportion < 0 || proportion > 1.0) {
335                        throw new IllegalArgumentException("Proportion must be between 0 and 1");
336                }
337
338                Dataset p1RealContribution = Maths.multiply(p1.getRealPoint(), (1.0 - proportion));
339                Dataset p2RealContribution = Maths.multiply(p2.getRealPoint(), (proportion));
340
341                Dataset realPoint = Maths.add(p1RealContribution, p2RealContribution);
342
343                Dataset p1CoordContribution = Maths.multiply(p1.getCoordPoint(), (1.0 - proportion));
344                Dataset p2CoordContribution = Maths.multiply(p2.getCoordPoint(), (proportion));
345
346                Dataset coordPoint = Maths.add(p1CoordContribution, p2CoordContribution);
347
348                return new InterpolatedPoint(realPoint, coordPoint);
349        }
350
351        /**
352         * Checks to see if 2 points have the same dimensionality
353         * 
354         * @param p1
355         *            Point 1
356         * @param p2
357         *            Point 2
358         * @throws IllegalArgumentException
359         */
360        private static void checkPoints(InterpolatedPoint p1, InterpolatedPoint p2) throws IllegalArgumentException {
361                if (!p1.getCoordPoint().isCompatibleWith(p2.getCoordPoint())) {
362                        throw new IllegalArgumentException("Datasets do not match");
363                }
364        }
365
366        
367        
368        
369        
370        
371        private static Dataset getTrimmedAxis(Dataset axis, int axisIndex, InterpolatedPoint p1, InterpolatedPoint p2) {
372                double startPoint = p1.getRealPoint().getDouble(axisIndex);
373                double endPoint = p2.getRealPoint().getDouble(axisIndex);
374                
375                // swap if needed
376                if (startPoint > endPoint) {
377                        startPoint = p2.getRealPoint().getDouble(axisIndex);
378                        endPoint = p1.getRealPoint().getDouble(axisIndex);
379                }
380
381                int start = getTrimmedAxisStart(axis, startPoint);
382                int end = getTrimmedAxisEnd(axis, start, endPoint);
383                
384                return axis.getSlice(new int[] {start}, new int[] {end}, null);
385        }
386
387        private static int getTrimmedAxisStart(Dataset axis, double startPoint) {
388                for (int i = 0; i < axis.getShapeRef()[0]; i++) {
389                        if (axis.getDouble(i) > startPoint) return i;
390                }
391                // if we get to here then the start point is higher than the whole system
392                return -1;
393        }
394        
395        private static int getTrimmedAxisEnd(Dataset axis, int startPos, double endPoint) {
396                for (int i = startPos; i < axis.getShapeRef()[0]; i++) {
397                        if (axis.getDouble(i) > endPoint) return i-1;
398                }
399                // if we get to here then the end point is higher than the whole system
400                return axis.getShapeRef()[0];
401        }
402        
403        public static Dataset remap1D(Dataset dataset, Dataset axis, Dataset outputAxis) {
404                Dataset data = DatasetFactory.zeros(DoubleDataset.class, outputAxis.getShapeRef());
405                for(int i = 0; i < outputAxis.getShapeRef()[0]; i++) {
406                        double point = outputAxis.getDouble(i);
407                        double position = getRealPositionAsIndex(axis, point);
408                        if (position >= 0.0) {
409                                data.set(Maths.interpolate(dataset, position), i);
410                        } else {
411                                data.set(Double.NaN,i);
412                        }
413                }
414                
415                return data;
416        }
417
418        // TODO need to make this work with reverse number lists
419        private static double getRealPositionAsIndex(Dataset dataset, double point) {
420                for (int j = 0; j < dataset.getShapeRef()[0]-1; j++) {
421                        double end = dataset.getDouble(j+1);
422                        double start = dataset.getDouble(j);
423                        //TODO could make this check once outside the loop with a minor assumption.
424                        if ( start < end) {
425                                if ((end > point) && (start <= point)) {
426                                        // we have a bounding point
427                                        double proportion = ((point-start)/(end-start));
428                                        return j + proportion;
429                                }
430                        } else {
431                                if ((end < point) && (start >= point)) {
432                                        // we have a bounding point
433                                        double proportion = ((point-start)/(end-start));
434                                        return j + proportion;
435                                }
436                        }
437                }
438                return -1.0;
439        }
440        
441        public static Dataset remapOneAxis(Dataset dataset, int axisIndex, Dataset corrections,
442                        Dataset originalAxisForCorrection, Dataset outputAxis) {
443                int[] stop = dataset.getShape();
444                int[] start = new int[stop.length];
445                int[] step = new int[stop.length];
446                int[] resultSize = new int[stop.length];
447                for (int i = 0 ; i < start.length; i++) {
448                        start[i] = 0;
449                        step[i] = 1;
450                        resultSize[i] = stop[i];
451                }
452                
453                resultSize[axisIndex] = outputAxis.getShapeRef()[0];
454                DoubleDataset result = DatasetFactory.zeros(DoubleDataset.class, resultSize);
455                
456                step[axisIndex] = dataset.getShapeRef()[axisIndex];
457                IndexIterator iter = dataset.getSliceIterator(start, stop, step);
458                
459                int[] pos = iter.getPos();
460                int[] posEnd = new int[pos.length];
461                while (iter.hasNext()){
462                        for (int i = 0 ; i < posEnd.length; i++) {
463                                posEnd[i] = pos[i]+1;
464                        }
465                        posEnd[axisIndex] = stop[axisIndex];
466                        // get the dataset
467                        Dataset slice = dataset.getSlice(pos, posEnd, null).squeeze();
468                        int[] correctionPos = new int[pos.length-1];
469                        int index = 0;
470                        for(int j = 0; j < pos.length; j++) {
471                                if (j != axisIndex) {
472                                        correctionPos[index] = pos[j];
473                                        index++;
474                                }
475                        }
476                        Dataset axis = Maths.subtract(originalAxisForCorrection,corrections.getDouble(correctionPos));
477                        Dataset remapped = remap1D(slice,axis,outputAxis);
478                        
479                        int[] ref = ArrayUtils.clone(pos);
480                        
481                        for (int k = 0; k < result.getShapeRef()[axisIndex]; k++) {
482                                ref[axisIndex] = k;
483                                result.set(remapped.getDouble(k), ref);
484                        }
485                }
486                
487                return result;
488        }
489        
490        
491        public static Dataset remapAxis(Dataset dataset, int axisIndex, Dataset originalAxisForCorrection, Dataset outputAxis) {
492                if (!dataset.isCompatibleWith(originalAxisForCorrection)) {
493                        throw new IllegalArgumentException("Datasets must be of the same shape");
494                }
495                
496                int[] stop = dataset.getShapeRef();
497                int[] start = new int[stop.length];
498                int[] step = new int[stop.length];
499                int[] resultSize = new int[stop.length];
500                for (int i = 0 ; i < start.length; i++) {
501                        start[i] = 0;
502                        step[i] = 1;
503                        resultSize[i] = stop[i];
504                }
505                
506                resultSize[axisIndex] = outputAxis.getShapeRef()[0];
507                DoubleDataset result = DatasetFactory.zeros(DoubleDataset.class, resultSize);
508                
509                step[axisIndex] = dataset.getShapeRef()[axisIndex];
510                IndexIterator iter = dataset.getSliceIterator(start, stop, step);
511                
512                int[] pos = iter.getPos();
513                int[] posEnd = new int[pos.length];
514                while (iter.hasNext()){
515                        for (int i = 0 ; i < posEnd.length; i++) {
516                                posEnd[i] = pos[i]+1;
517                        }
518                        posEnd[axisIndex] = stop[axisIndex];
519                        
520                        // get the dataset
521                        Dataset slice = dataset.getSlice(pos, posEnd, null).squeeze();
522                        Dataset axis = originalAxisForCorrection.getSlice(pos, posEnd, null).squeeze();
523                        
524                        Dataset remapped = remap1D(slice,axis,outputAxis);
525                        
526                        int[] ref = ArrayUtils.clone(pos);
527                        
528                        for (int k = 0; k < result.shape[axisIndex]; k++) {
529                                ref[axisIndex] = k;
530                                result.set(remapped.getDouble(k), ref);
531                        }
532                }
533                
534                return result;
535        }
536
537        public static Dataset regrid(Dataset data, Dataset x, Dataset y, Dataset gridX, Dataset gridY) {
538                
539                // apply X then Y regridding
540                Dataset result = remapAxis(data,1,x,gridX);
541                result = remapAxis(result,0,y,gridY);
542                
543                return result;
544        }
545}