1 回答

TA貢獻(xiàn)1111條經(jīng)驗(yàn) 獲得超0個(gè)贊
好吧,所以我發(fā)現(xiàn)了一個(gè)問題,我無法為我想要的所有行(預(yù)測(cè))運(yùn)行一次。可能是一個(gè)張量流新手問題,我搞砸了輸入和輸出矩陣。當(dāng)報(bào)告工具(python)說你有一個(gè)形狀(-1,9)的輸入張量映射到j(luò)ava long[]{1,9}時(shí),這并不意味著你不能傳遞long[]{1000,9}的輸入張量 - 這意味著1000行用于預(yù)測(cè)。在此輸入之后,定義為 [1,1] 的輸出張量可以是 [1000,1]。
這個(gè)代碼實(shí)際上比python運(yùn)行得快得多(1.2秒對(duì)7秒),這是代碼(也許會(huì)解釋得更好)
public Tensor prepareData(){
Random r = new Random();
float[]inputArr = new float[NUMBER_OF_KEWORDS*NUMBER_OF_FIELDS];
for (int i=0;i<NUMBER_OF_KEWORDS * NUMBER_OF_FIELDS;i++){
inputArr[i] = r.nextFloat();
}
FloatBuffer inputBuff = FloatBuffer.wrap(inputArr, 0, NUMBER_OF_KEWORDS*NUMBER_OF_FIELDS);
return Tensor.create(new long[]{NUMBER_OF_KEWORDS,NUMBER_OF_FIELDS}, inputBuff);
}
public void predict (Tensor inputTensor){
try ( Session s = savedModelBundle.session()) {
Tensor result;
long globalStart = System.nanoTime();
result = s.runner().feed("dense_1_input", inputTensor).fetch("dense_4/BiasAdd").run().get(0);
final long[] rshape = result.shape();
if (result.numDimensions() != 2 || rshape[0] <= NUMBER_OF_KEWORDS) {
throw new RuntimeException(
String.format(
"Expected model to produce a [N,1] shaped tensor where N is the number of labels, instead it produced one with shape %s",
Arrays.toString(rshape)));
}
float[][] resultArray = (float[][]) result.copyTo(new float[NUMBER_OF_KEWORDS][1]);
System.out.println(String.format("Total of %d, took : %.4f ms", NUMBER_OF_KEWORDS, ((double) System.nanoTime() - globalStart) / 1000000));
for (int i=0;i<10;i++){
System.out.println(resultArray[i][0]);
}
}
}
添加回答
舉報(bào)