第七色在线视频,2021少妇久久久久久久久久,亚洲欧洲精品成人久久av18,亚洲国产精品特色大片观看完整版,孙宇晨将参加特朗普的晚宴

為了賬號(hào)安全,請(qǐng)及時(shí)綁定郵箱和手機(jī)立即綁定
已解決430363個(gè)問題,去搜搜看,總會(huì)有你想問的

將 python/numpy 索引轉(zhuǎn)移到 Tensorflow 并提高性能

將 python/numpy 索引轉(zhuǎn)移到 Tensorflow 并提高性能

www說 2023-07-27 10:36:06
在之前的一個(gè)問題中,我詢問了有關(guān)更快地將項(xiàng)目分配給數(shù)組的建議。從那時(shí)起,我取得了一些進(jìn)展,例如,我擴(kuò)展了推薦的版本來處理 3-D 數(shù)組,其目的是類似于神經(jīng)網(wǎng)絡(luò)后續(xù)訓(xùn)練數(shù)據(jù)的批量大?。篿mport numpy as npimport timebatch_dim = 2first_dim = 5second_dim = 7depth_dim = 10upper_count = 5000toy_dict = {k:np.random.random_sample(size = depth_dim) for k in range(upper_count)}a = np.array(list(toy_dict.values()))def create_input_3d(orig_arr):  print("Input shape:", orig_arr.shape)  goal_arr = np.full(shape=(batch_dim, orig_arr.shape[1], orig_arr.shape[2], depth_dim), fill_value=1234, dtype=float)  print("Goal shape:", goal_arr.shape)  idx = np.indices(orig_arr.shape)  print("Idx shape", idx.shape)  goal_arr[idx[0], idx[1], idx[2]] = a[orig_arr[idx[0], idx[1], idx[2]]]  return goal_arrorig_arr_three_dim = np.random.randint(0, upper_count, size=(batch_dim, first_dim, second_dim))orig_arr_three_dim.shape # (2,5,7)reshaped = create_input_3d(orig_arr_three_dim)然后,我決定創(chuàng)建一個(gè)自定義層以提高性能并即時(shí)進(jìn)行轉(zhuǎn)換(減少內(nèi)存):import tensorflow as tffrom tensorflow import kerasimport numpy as np#custom layerclass CustLayer(keras.layers.Layer):    def __init__(self, info_matrix, first_dim, second_dim, info_dim, batch_size):        super(CustLayer, self).__init__()        self.w = tf.Variable(            initial_value=info_matrix,            trainable=False,            dtype=tf.dtypes.float32        )        self.info_dim = info_dim        self.first_dim = first_dim        self.second_dim = second_dim        self.batch_size = batch_size由于高級(jí)索引(如我第一個(gè)發(fā)布的代碼中)不起作用,我回到了天真的 for 循環(huán) - 這太慢了。我正在尋找的是一種使用第一個(gè)代碼片段中所示的高級(jí)索引的方法,并將其重新編程為 tf 兼容。這讓我以后能夠使用 GPU 進(jìn)行學(xué)習(xí)。簡而言之:輸入的形狀為(batch_size, first_dim, second_dim),返回的形狀為(batch_size, first_dim, second_dim, info_dim),擺脫了緩慢的 for 循環(huán)。提前致謝。
查看完整描述

1 回答

?
呼喚遠(yuǎn)方

TA貢獻(xiàn)1856條經(jīng)驗(yàn) 獲得超11個(gè)贊

對(duì)于其他尋找答案的人來說,這就是我最終想出的:


import tensorflow as tf

from tensorflow import keras

import numpy as np

import time


class CustLayer(keras.layers.Layer):

    def __init__(self, info_matrix, first_dim, second_dim, info_dim, batch_size):

        super(CustLayer, self).__init__()

        self.w = tf.Variable(

            initial_value=info_matrix,

            trainable=False,

            dtype=tf.dtypes.float32

        )

        self.info_matrix = info_matrix

        self.info_dim = info_dim

        self.first_dim = first_dim

        self.second_dim = second_dim

        self.batch_size = batch_size

   

    def my_numpy_func(self, x):

      # x will be a numpy array with the contents of the input to the

      # tf.function

      shape = x.shape

      goal_arr = np.zeros(shape=(shape[0], shape[1], shape[2], self.info_dim), dtype=np.float32)


      # indices to expand

      idx = np.indices(shape)

      goal_arr[idx[0], idx[1], idx[2]] = self.info_matrix[x[idx[0], idx[1], idx[2]]]


      shape_arr = np.array([shape[0], shape[1], shape[2]], dtype=np.int8)

      #tf.print("Shape:", shape)

      #tf.print("Shape_arr:", shape_arr)

      #tf.print("Type:",type(shape_arr))

      return goal_arr, shape_arr


    @tf.function(input_signature=[tf.TensorSpec((None, 39, 25), tf.int64)])

    def tf_function(self, input):

      

      y, shape_arr = tf.numpy_function(self.my_numpy_func, [input], [tf.float32, tf.int8], "Nameless")

      #tf.print("shape_arr", shape_arr)

      y = tf.reshape(y, shape=(shape_arr[0], shape_arr[1], shape_arr[2], self.info_dim))

      return y


    def call(self, orig_arr):

      return self.tf_function(orig_arr)

      

注意事項(xiàng):在 GPU 上運(yùn)行,但不能在 TPU 上運(yùn)行。


查看完整回答
反對(duì) 回復(fù) 2023-07-27
  • 1 回答
  • 0 關(guān)注
  • 146 瀏覽
慕課專欄
更多

添加回答

舉報(bào)

0/150
提交
取消
微信客服

購課補(bǔ)貼
聯(lián)系客服咨詢優(yōu)惠詳情

幫助反饋 APP下載

慕課網(wǎng)APP
您的移動(dòng)學(xué)習(xí)伙伴

公眾號(hào)

掃描二維碼
關(guān)注慕課網(wǎng)微信公眾號(hào)