001/*-
002 * Copyright 2017 Diamond Light Source Ltd.
003 *
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 */
009
010package org.eclipse.january.dataset;
011
012import java.util.Arrays;
013import java.util.List;
014
015/**
016 * Class to run over a pair of datasets in parallel with NumPy broadcasting to promote shapes
017 * which have lower rank and outputs to a third dataset
018 * @since 2.1
019 */
020public class BooleanBroadcastIterator extends BooleanIteratorBase {
021        private int[] cShape;
022        private int[] cStride;
023
024        private final int[] cDelta;
025        private final int cStep;
026        private int cMax;
027        private int cStart;
028
029        /**
030         * Construct a boolean iterator that stops at every position in the choice dataset where its value matches
031         * the given boolean
032         * @param v boolean value
033         * @param a primary dataset
034         * @param c choice dataset
035         * @param o output dataset, can be null
036         * @param createIfNull if true create the output dataset if that is null
037         */
038        public BooleanBroadcastIterator(boolean v, Dataset a, Dataset c, Dataset o, boolean createIfNull) {
039                super(v, a, c, o);
040                List<int[]> fullShapes = BroadcastUtils.broadcastShapes(a.getShapeRef(), c.getShapeRef(), o == null ? null : o.getShapeRef());
041
042                maxShape = fullShapes.remove(0);
043
044                oStride = null;
045                if (o != null && !Arrays.equals(maxShape, o.getShapeRef())) {
046                        throw new IllegalArgumentException("Output does not match broadcasted shape");
047                }
048                aShape = fullShapes.remove(0);
049                cShape = fullShapes.remove(0);
050
051                int rank = maxShape.length;
052                endrank = rank - 1;
053
054                aDataset = a.reshape(aShape);
055                cDataset = c.reshape(cShape);
056                aStride = BroadcastUtils.createBroadcastStrides(aDataset, maxShape);
057                cStride = BroadcastUtils.createBroadcastStrides(cDataset, maxShape);
058                if (outputA) {
059                        oStride = aStride;
060                        oDelta = null;
061                        oStep = 0;
062                } else if (o != null) {
063                        oStride = BroadcastUtils.createBroadcastStrides(o, maxShape);
064                        oDelta = new int[rank];
065                        oStep = o.getElementsPerItem();
066                } else if (createIfNull) {
067                        oDataset = BroadcastUtils.createDataset(a, c, maxShape);
068                        oStride = BroadcastUtils.createBroadcastStrides(oDataset, maxShape);
069                        oDelta = new int[rank];
070                        oStep = oDataset.getElementsPerItem();
071                } else {
072                        oDelta = null;
073                        oStep = 0;
074                }
075
076                pos = new int[rank];
077                aDelta = new int[rank];
078                cDelta = new int[rank];
079                cStep = cDataset.getElementsPerItem();
080                for (int j = endrank; j >= 0; j--) {
081                        aDelta[j] = aStride[j] * aShape[j];
082                        cDelta[j] = cStride[j] * cShape[j];
083                        if (oDelta != null) {
084                                oDelta[j] = oStride[j] * maxShape[j];
085                        }
086                }
087                aStart = aDataset.getOffset();
088                cStart = cDataset.getOffset();
089                aMax = endrank < 0 ? aStep + aStart: Integer.MIN_VALUE;
090                cMax = endrank < 0 ? cStep + cStart: Integer.MIN_VALUE;
091                oStart = oDelta == null ? 0 : oDataset.getOffset();
092                reset();
093        }
094
095        @Override
096        public boolean hasNext() {
097                do {
098                        int j = endrank;
099                        for (; j >= 0; j--) {
100                                pos[j]++;
101                                index += aStride[j];
102                                cIndex += cStride[j];
103                                if (oDelta != null) {
104                                        oIndex += oStride[j];
105                                }
106                                if (pos[j] >= maxShape[j]) {
107                                        pos[j] = 0;
108                                        index -= aDelta[j]; // reset these dimensions
109                                        cIndex -= cDelta[j];
110                                        if (oDelta != null) {
111                                                oIndex -= oDelta[j];
112                                        }
113                                } else {
114                                        break;
115                                }
116                        }
117                        if (j == -1) {
118                                if (endrank >= 0) {
119                                        return false;
120                                }
121                                index += aStep;
122                                cIndex += cStep;
123                                if (oDelta != null) {
124                                        oIndex += oStep;
125                                }
126                        }
127                        if (outputA) {
128                                oIndex = index;
129                        }
130        
131                        if (index == aMax || cIndex == cMax) {
132                                return false;
133                        }
134                } while (cDataset.getElementBooleanAbs(cIndex) != value);
135
136                return true;
137        }
138
139        /**
140         * @return shape of first broadcasted dataset
141         */
142        public int[] getFirstShape() {
143                return aShape;
144        }
145
146        /**
147         * @return shape of second broadcasted dataset
148         */
149        public int[] getMaskShape() {
150                return cShape;
151        }
152
153        @Override
154        public void reset() {
155                for (int i = 0; i <= endrank; i++) {
156                        pos[i] = 0;
157                }
158
159                if (endrank >= 0) {
160                        pos[endrank] = -1;
161                        index = aStart - aStride[endrank];
162                        cIndex = cStart - cStride[endrank];
163                        oIndex = oStart - (oStride == null ? 0 : oStride[endrank]);
164                } else {
165                        index = aStart - aStep;
166                        cIndex = cStart - cStep;
167                        oIndex = oStart - oStep;
168                }
169        }
170}